mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 00:22:31 +00:00
test(agent): expand coverage and refactor test structure
- Add 42 tests for ContextBuilder (context.py: 0→42 tests)
- Add 37 tests for SubagentManager lifecycle (subagent.py: 2→37 tests)
- Add 42 unit tests for AutoCompact in isolation
- Split monolithic test_runner.py (3313 lines) into 9 focused files: test_runner_core, test_runner_hooks, test_runner_errors, test_runner_safety, test_runner_persistence, test_runner_governance, test_runner_tool_execution, test_runner_injections, test_loop_runner_integration
- Add 3 config passthrough tests (temperature/max_tokens/reasoning_effort)
- Fix fragile patch.object(__init__) in test_stop_preserves_context
- Create shared conftest.py with make_provider/make_loop factories

Total: 934 tests passing, 0 regressions
This commit is contained in:
parent
00597fccd6
commit
99cc6ee808
93
tests/agent/conftest.py
Normal file
93
tests/agent/conftest.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Shared fixtures and helpers for agent tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
def make_provider(
    default_model: str = "test-model",
    *,
    max_tokens: int = 4096,
    spec: bool = True,
) -> MagicMock:
    """Create an LLM provider mock pre-configured for agent tests.

    Args:
        default_model: Value returned by ``provider.get_default_model()``.
        max_tokens: Value exposed as ``provider.generation.max_tokens``.
        spec: If True, build the mock with ``spec=LLMProvider``; if False,
            use a plain, unrestricted ``MagicMock``.

    Returns:
        A ``MagicMock`` whose ``get_default_model`` and
        ``estimate_prompt_tokens`` return fixed values and whose
        ``generation`` attribute is a ``SimpleNamespace`` carrying
        ``max_tokens``, ``temperature`` and ``reasoning_effort``.
    """
    # Build the mock instance directly; no intermediate alias is needed.
    provider = MagicMock(spec=LLMProvider) if spec else MagicMock()
    provider.get_default_model.return_value = default_model
    provider.generation = SimpleNamespace(
        max_tokens=max_tokens,
        temperature=0.1,
        reasoning_effort=None,
    )
    # Fixed (token_count, label) pair so tests see a deterministic estimate.
    provider.estimate_prompt_tokens.return_value = (10_000, "test")
    return provider
|
||||
|
||||
|
||||
def make_loop(
    tmp_path: Path,
    *,
    model: str = "test-model",
    context_window_tokens: int = 128_000,
    session_ttl_minutes: int = 0,
    max_messages: int = 120,
    unified_session: bool = False,
    mcp_servers: dict | None = None,
    tools_config=None,
    model_presets: dict | None = None,
    hooks: list | None = None,
    provider: MagicMock | None = None,
    patch_deps: bool = False,
) -> AgentLoop:
    """Build a real ``AgentLoop`` rooted at *tmp_path* for use in tests.

    A fresh ``MessageBus`` is always created. When *provider* is omitted a
    mock from ``make_provider`` (using *model* as its default model) is used.
    ``mcp_servers``, ``tools_config``, ``model_presets`` and ``hooks`` are
    forwarded to ``AgentLoop`` only when they are not ``None``.

    Args:
        patch_deps: If True, patch ContextBuilder/SessionManager/SubagentManager
            during construction (needed when workspace has no real files).
    """
    if provider is None:
        provider = make_provider(default_model=model)

    loop_kwargs = {
        "bus": MessageBus(),
        "provider": provider,
        "workspace": tmp_path,
        "model": model,
        "context_window_tokens": context_window_tokens,
        "session_ttl_minutes": session_ttl_minutes,
        "max_messages": max_messages,
        "unified_session": unified_session,
    }

    # Optional settings are passed through only when the caller supplied them.
    for name, value in (
        ("mcp_servers", mcp_servers),
        ("tools_config", tools_config),
        ("model_presets", model_presets),
        ("hooks", hooks),
    ):
        if value is not None:
            loop_kwargs[name] = value

    if not patch_deps:
        return AgentLoop(**loop_kwargs)

    # The patches are active only while the loop is being constructed.
    with patch("nanobot.agent.loop.ContextBuilder"):
        with patch("nanobot.agent.loop.SessionManager"):
            with patch("nanobot.agent.loop.SubagentManager") as subagent_cls:
                subagent_cls.return_value.cancel_by_session = AsyncMock(return_value=0)
                return AgentLoop(**loop_kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
def loop_factory(tmp_path):
    """Yield a callable that builds AgentLoop instances in this test's tmp_path.

    Keyword arguments given to the callable are forwarded to ``make_loop``.
    """
    def _build_loop(**overrides):
        return make_loop(tmp_path, **overrides)

    return _build_loop
|
||||
554
tests/agent/test_autocompact_unit.py
Normal file
554
tests/agent/test_autocompact_unit.py
Normal file
@ -0,0 +1,554 @@
|
||||
"""Direct unit tests for AutoCompact class methods in isolation."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.autocompact import AutoCompact
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
|
||||
|
||||
def _make_session(
    key: str = "cli:test",
    messages: list | None = None,
    last_consolidated: int = 0,
    updated_at: datetime | None = None,
    metadata: dict | None = None,
) -> Session:
    """Build a Session for tests, filling in empty defaults.

    Falsy ``messages``/``metadata`` are replaced by a new empty list/dict.
    ``updated_at`` is assigned onto the session only when it is provided.
    """
    fields = {
        "key": key,
        "messages": messages or [],
        "metadata": metadata or {},
        "last_consolidated": last_consolidated,
    }
    built = Session(**fields)
    if updated_at is None:
        return built
    built.updated_at = updated_at
    return built
|
||||
|
||||
|
||||
def _make_autocompact(
    ttl: int = 15,
    sessions: SessionManager | None = None,
    consolidator: MagicMock | None = None,
) -> AutoCompact:
    """Build an AutoCompact wired to mocks for any dependency not supplied.

    A missing ``sessions`` becomes a ``MagicMock`` spec'd to SessionManager;
    a missing ``consolidator`` becomes a ``MagicMock`` whose ``archive`` is an
    ``AsyncMock`` returning ``"Summary."``.
    """
    if consolidator is None:
        consolidator = MagicMock()
        consolidator.archive = AsyncMock(return_value="Summary.")
    if sessions is None:
        sessions = MagicMock(spec=SessionManager)
    return AutoCompact(
        sessions=sessions,
        consolidator=consolidator,
        session_ttl_minutes=ttl,
    )
|
||||
|
||||
|
||||
def _add_turns(session: Session, turns: int, *, prefix: str = "msg") -> None:
    """Append *turns* user/assistant message pairs to *session*.

    Each message body is ``"<prefix> <role> <index>"``; within a turn the
    user message is added before the assistant message.
    """
    for index in range(turns):
        for role in ("user", "assistant"):
            session.add_message(role, f"{prefix} {role} {index}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# __init__
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInit:
    """Verify that AutoCompact.__init__ records its constructor arguments."""

    def test_stores_ttl(self):
        """The session_ttl_minutes argument is exposed as _ttl."""
        compactor = _make_autocompact(ttl=30)
        assert compactor._ttl == 30

    def test_default_ttl_is_zero(self):
        """A TTL of 0 is stored unchanged.

        NOTE(review): ttl=0 is passed explicitly here, so this does not
        exercise whatever default AutoCompact itself declares — confirm
        whether that was the intent.
        """
        compactor = _make_autocompact(ttl=0)
        assert compactor._ttl == 0

    def test_archiving_set_is_empty(self):
        """A new instance has an empty _archiving set."""
        compactor = _make_autocompact()
        assert compactor._archiving == set()

    def test_summaries_dict_is_empty(self):
        """A new instance has an empty _summaries dict."""
        compactor = _make_autocompact()
        assert compactor._summaries == {}

    def test_stores_sessions_reference(self):
        """The very SessionManager object passed in is kept on .sessions."""
        manager = MagicMock(spec=SessionManager)
        compactor = _make_autocompact(sessions=manager)
        assert compactor.sessions is manager

    def test_stores_consolidator_reference(self):
        """The very consolidator object passed in is kept on .consolidator."""
        consolidator = MagicMock()
        compactor = _make_autocompact(consolidator=consolidator)
        assert compactor.consolidator is consolidator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_expired
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsExpired:
    """Boundary and input-type checks for AutoCompact._is_expired."""

    # Fixed reference instant so boundary arithmetic is deterministic.
    _NOON = datetime(2026, 1, 1, 12, 0, 0)

    def test_ttl_zero_always_false(self):
        """With ttl=0 even a year-old timestamp is reported as not expired."""
        compactor = _make_autocompact(ttl=0)
        year_ago = datetime.now() - timedelta(days=365)
        assert compactor._is_expired(year_ago) is False

    def test_none_timestamp_returns_false(self):
        """A None timestamp is reported as not expired."""
        compactor = _make_autocompact(ttl=15)
        assert compactor._is_expired(None) is False

    def test_empty_string_timestamp_returns_false(self):
        """An empty-string timestamp is reported as not expired."""
        compactor = _make_autocompact(ttl=15)
        assert compactor._is_expired("") is False

    def test_exactly_at_boundary_is_expired(self):
        """A timestamp exactly ttl minutes old counts as expired."""
        compactor = _make_autocompact(ttl=15)
        stamp = self._NOON - timedelta(minutes=15)
        assert compactor._is_expired(stamp, now=self._NOON) is True

    def test_just_under_boundary_not_expired(self):
        """A timestamp one second short of the TTL is not expired."""
        compactor = _make_autocompact(ttl=15)
        stamp = self._NOON - timedelta(minutes=14, seconds=59)
        assert compactor._is_expired(stamp, now=self._NOON) is False

    def test_iso_string_parses_correctly(self):
        """An ISO-format string older than the TTL is reported as expired."""
        compactor = _make_autocompact(ttl=15)
        twenty_minutes_ago = self._NOON - timedelta(minutes=20)
        assert compactor._is_expired(twenty_minutes_ago.isoformat(), now=self._NOON) is True

    def test_custom_now_parameter(self):
        """The supplied 'now' decides the outcome: 9 min → False, 10 min → True."""
        compactor = _make_autocompact(ttl=10)
        started = datetime(2026, 1, 1, 10, 0, 0)
        for elapsed_minutes, expected in ((9, False), (10, True)):
            moment = started + timedelta(minutes=elapsed_minutes)
            assert compactor._is_expired(started, now=moment) is expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _format_summary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatSummary:
    """Checks on the string produced by AutoCompact._format_summary."""

    def test_contains_isoformat_timestamp(self):
        """The last-active time appears in ISO format in the output."""
        formatted = AutoCompact._format_summary(
            "Some text", datetime(2026, 5, 13, 14, 30, 0)
        )
        assert "2026-05-13T14:30:00" in formatted

    def test_contains_summary_text(self):
        """The summary text appears verbatim in the output."""
        body = "User discussed Python."
        formatted = AutoCompact._format_summary(body, datetime(2026, 1, 1))
        assert "User discussed Python." in formatted

    def test_output_starts_with_label(self):
        """The output begins with the standard label prefix."""
        formatted = AutoCompact._format_summary("text", datetime(2026, 1, 1))
        assert formatted.startswith("Previous conversation summary (last active ")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _split_unconsolidated
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSplitUnconsolidated:
    """Checks on how AutoCompact._split_unconsolidated divides a session."""

    @staticmethod
    def _split(count, **session_kwargs):
        """Split a session of *count* user messages ("u0".."u<count-1>").

        Returns ``(messages, archive, kept)``.
        """
        history = [{"role": "user", "content": f"u{i}"} for i in range(count)]
        session = _make_session(messages=history, **session_kwargs)
        archive, kept = _make_autocompact()._split_unconsolidated(session)
        return history, archive, kept

    def test_empty_session_returns_both_empty(self):
        """A session without messages yields two empty lists."""
        _, archive, kept = self._split(0)
        assert archive == []
        assert kept == []

    def test_all_messages_archivable_when_more_than_suffix(self):
        """With 20 messages some are archived and the kept part is bounded."""
        _, archive, kept = self._split(20)
        assert len(archive) > 0
        assert len(kept) <= AutoCompact._RECENT_SUFFIX_MESSAGES

    def test_fewer_messages_than_suffix_returns_empty_archive(self):
        """With only 3 messages nothing is archived and all are kept."""
        history, archive, kept = self._split(3)
        assert archive == []
        assert len(kept) == len(history)

    def test_respects_last_consolidated_offset(self):
        """With last_consolidated=10 only u10..u19 appear in either part."""
        _, archive, kept = self._split(20, last_consolidated=10)
        unconsolidated = {f"u{i}" for i in range(10, 20)}
        assert all(entry["content"] in unconsolidated for entry in kept)
        assert all(entry["content"] in unconsolidated for entry in archive)

    def test_retain_recent_legal_suffix_keeps_last_n(self):
        """Kept part is bounded and archive + kept account for every message."""
        history, archive, kept = self._split(20)
        assert len(kept) <= AutoCompact._RECENT_SUFFIX_MESSAGES
        assert len(archive) == len(history) - len(kept)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_expired
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckExpired:
    """Checks on when AutoCompact.check_expired invokes the scheduler."""

    @staticmethod
    def _stale_timestamp():
        """ISO timestamp 20 minutes in the past (older than the 15-minute TTL)."""
        return (datetime.now() - timedelta(minutes=20)).isoformat()

    @staticmethod
    def _check(listing, *, archiving=(), **check_kwargs):
        """Run check_expired against a session manager that lists *listing*.

        Keys in *archiving* are added to ``_archiving`` beforehand. Returns
        ``(autocompact, scheduler_mock)``.
        """
        manager = MagicMock(spec=SessionManager)
        manager.list_sessions.return_value = listing
        compactor = _make_autocompact(ttl=15)
        compactor.sessions = manager
        compactor._archiving.update(archiving)
        scheduler = MagicMock()
        compactor.check_expired(scheduler, **check_kwargs)
        return compactor, scheduler

    def test_empty_sessions_list(self):
        """With no sessions listed the scheduler is never called."""
        _, scheduler = self._check([])
        scheduler.assert_not_called()

    def test_expired_session_schedules_background(self):
        """A stale session is scheduled once and its key enters _archiving."""
        listing = [{"key": "cli:old", "updated_at": self._stale_timestamp()}]
        compactor, scheduler = self._check(listing)
        scheduler.assert_called_once()
        assert "cli:old" in compactor._archiving

    def test_active_session_key_skips(self):
        """A stale session named in active_session_keys is not scheduled."""
        listing = [{"key": "cli:busy", "updated_at": self._stale_timestamp()}]
        _, scheduler = self._check(listing, active_session_keys={"cli:busy"})
        scheduler.assert_not_called()

    def test_session_already_in_archiving_skips(self):
        """A stale session whose key is already in _archiving is not scheduled."""
        listing = [{"key": "cli:dup", "updated_at": self._stale_timestamp()}]
        _, scheduler = self._check(listing, archiving=("cli:dup",))
        scheduler.assert_not_called()

    def test_session_with_no_key_skips(self):
        """A listing entry with an empty key is not scheduled."""
        _, scheduler = self._check([{"key": "", "updated_at": "old"}])
        scheduler.assert_not_called()

    def test_session_with_missing_key_field_skips(self):
        """A listing entry lacking the 'key' field is not scheduled."""
        _, scheduler = self._check([{"updated_at": "old"}])
        scheduler.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _archive
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestArchive:
    """Checks on the async AutoCompact._archive method."""

    @staticmethod
    def _twenty_messages():
        """Twenty user messages, "u0" through "u19"."""
        return [{"role": "user", "content": f"u{i}"} for i in range(20)]

    @staticmethod
    def _wire(session, archive):
        """Build an AutoCompact whose session manager hands back *session*.

        *archive* replaces ``consolidator.archive``. Returns
        ``(autocompact, session_manager_mock)``.
        """
        compactor = _make_autocompact()
        manager = MagicMock(spec=SessionManager)
        manager.get_or_create.return_value = session
        compactor.sessions = manager
        compactor.consolidator.archive = archive
        return compactor, manager

    @pytest.mark.asyncio
    async def test_empty_session_updates_timestamp_no_archive_call(self):
        """An empty session is saved with a fresh updated_at; archive is not called."""
        blank = _make_session(messages=[])
        compactor, manager = self._wire(blank, AsyncMock(return_value="Summary."))

        await compactor._archive("cli:test")

        compactor.consolidator.archive.assert_not_called()
        manager.save.assert_called_once_with(blank)
        # updated_at must have been set within the last few seconds.
        assert blank.updated_at > datetime.now() - timedelta(seconds=5)

    @pytest.mark.asyncio
    async def test_archive_returns_empty_string_no_summary_stored(self):
        """An empty-string archive result leaves _summaries without the key."""
        session = _make_session(messages=self._twenty_messages())
        compactor, _ = self._wire(session, AsyncMock(return_value=""))

        await compactor._archive("cli:test")

        assert "cli:test" not in compactor._summaries

    @pytest.mark.asyncio
    async def test_archive_returns_nothing_no_summary_stored(self):
        """An archive result of '(nothing)' leaves _summaries without the key."""
        session = _make_session(messages=self._twenty_messages())
        compactor, _ = self._wire(session, AsyncMock(return_value="(nothing)"))

        await compactor._archive("cli:test")

        assert "cli:test" not in compactor._summaries

    @pytest.mark.asyncio
    async def test_archive_exception_caught_key_removed_from_archiving(self):
        """A raising archive call does not propagate; key is absent from _archiving.

        NOTE(review): the key is never added to _archiving in this test, so the
        final assertion holds trivially; the no-raise behaviour is the real check.
        """
        session = _make_session(messages=self._twenty_messages())
        failing = AsyncMock(side_effect=RuntimeError("LLM down"))
        compactor, _ = self._wire(session, failing)

        # Must complete without raising.
        await compactor._archive("cli:test")

        assert "cli:test" not in compactor._archiving

    @pytest.mark.asyncio
    async def test_successful_archive_stores_summary_in_summaries_and_metadata(self):
        """A successful archive records the summary in _summaries and in metadata."""
        last_active = datetime(2026, 5, 13, 10, 0, 0)
        session = _make_session(
            messages=self._twenty_messages(), updated_at=last_active
        )
        compactor, _ = self._wire(session, AsyncMock(return_value="User discussed AI."))

        await compactor._archive("cli:test")

        # In-memory entry: (text, last_active).
        cached = compactor._summaries.get("cli:test")
        assert cached is not None
        assert cached[0] == "User discussed AI."
        assert cached[1] == last_active
        # Copy stored on the session metadata.
        stored = session.metadata.get("_last_summary")
        assert stored is not None
        assert stored["text"] == "User discussed AI."
        assert "last_active" in stored

    @pytest.mark.asyncio
    async def test_finally_block_always_removes_from_archiving(self):
        """A pre-registered key is removed from _archiving even when archive raises."""
        session = _make_session(messages=self._twenty_messages())
        failing = AsyncMock(side_effect=RuntimeError("fail"))
        compactor, _ = self._wire(session, failing)

        # Register the key first so its removal is observable.
        compactor._archiving.add("cli:test")
        await compactor._archive("cli:test")
        assert "cli:test" not in compactor._archiving

    @pytest.mark.asyncio
    async def test_finally_removes_from_archiving_on_success(self):
        """A pre-registered key is removed from _archiving after a successful run."""
        session = _make_session(messages=self._twenty_messages())
        compactor, _ = self._wire(session, AsyncMock(return_value="Summary."))

        compactor._archiving.add("cli:test")
        await compactor._archive("cli:test")
        assert "cli:test" not in compactor._archiving
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prepare_session
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPrepareSession:
    """Checks on AutoCompact.prepare_session reloading and summary lookup."""

    @staticmethod
    def _with_reload(compactor, reloaded):
        """Attach a session manager mock that returns *reloaded*; return the mock."""
        manager = MagicMock(spec=SessionManager)
        manager.get_or_create.return_value = reloaded
        compactor.sessions = manager
        return manager

    def test_key_in_archiving_reloads_session(self):
        """A key present in _archiving makes prepare_session reload the session."""
        compactor = _make_autocompact()
        fresh = _make_session(key="cli:test")
        manager = self._with_reload(compactor, fresh)
        compactor._archiving.add("cli:test")

        returned, _summary = compactor.prepare_session(_make_session(), "cli:test")

        manager.get_or_create.assert_called_once_with("cli:test")
        assert returned is fresh

    def test_expired_session_reloads(self):
        """A session older than the TTL is replaced via get_or_create."""
        compactor = _make_autocompact(ttl=15)
        fresh = _make_session(key="cli:test", updated_at=datetime.now())
        manager = self._with_reload(compactor, fresh)

        stale = _make_session(updated_at=datetime.now() - timedelta(minutes=20))
        returned, _summary = compactor.prepare_session(stale, "cli:test")

        manager.get_or_create.assert_called_once_with("cli:test")
        assert returned is fresh

    def test_hot_path_summary_from_summaries(self):
        """An entry in _summaries is returned, formatted, with the same session."""
        compactor = _make_autocompact()
        session = _make_session()
        compactor._summaries["cli:test"] = (
            "Hot summary.",
            datetime(2026, 5, 13, 14, 0, 0),
        )

        returned, summary = compactor.prepare_session(session, "cli:test")

        assert returned is session
        assert summary is not None
        assert "Hot summary." in summary
        assert "Previous conversation summary" in summary

    def test_hot_path_pops_summary_one_shot(self):
        """The _summaries entry is consumed: a second call yields no summary."""
        compactor = _make_autocompact()
        session = _make_session()
        compactor._summaries["cli:test"] = ("One-shot.", datetime(2026, 1, 1))

        _, first = compactor.prepare_session(session, "cli:test")
        assert first is not None
        # The entry was consumed by the first call.
        _, second = compactor.prepare_session(session, "cli:test")
        assert second is None

    def test_cold_path_summary_from_metadata(self):
        """With _summaries empty, the summary is taken from session metadata."""
        compactor = _make_autocompact()
        stored = {
            "text": "Cold summary.",
            "last_active": datetime(2026, 5, 13, 14, 0, 0).isoformat(),
        }
        session = _make_session(metadata={"_last_summary": stored})

        returned, summary = compactor.prepare_session(session, "cli:test")

        assert returned is session
        assert summary is not None
        assert "Cold summary." in summary

    def test_no_summary_available_returns_none(self):
        """With neither source populated the result is (session, None)."""
        compactor = _make_autocompact()
        session = _make_session()

        returned, summary = compactor.prepare_session(session, "cli:test")

        assert returned is session
        assert summary is None

    def test_cold_path_metadata_not_dict_returns_none(self):
        """A non-dict _last_summary in metadata yields no summary."""
        compactor = _make_autocompact()
        session = _make_session(metadata={"_last_summary": "not a dict"})

        returned, summary = compactor.prepare_session(session, "cli:test")

        assert returned is session
        assert summary is None

    def test_hot_path_takes_priority_over_metadata(self):
        """When both sources exist, the _summaries entry is the one returned."""
        compactor = _make_autocompact()
        stored = {
            "text": "Cold summary.",
            "last_active": datetime(2026, 1, 1).isoformat(),
        }
        session = _make_session(metadata={"_last_summary": stored})
        compactor._summaries["cli:test"] = (
            "Hot summary.",
            datetime(2026, 5, 13, 14, 0, 0),
        )

        _, summary = compactor.prepare_session(session, "cli:test")
        assert "Hot summary." in summary
        # After hot path pops, cold path would kick in on next call
|
||||
333
tests/agent/test_context_builder.py
Normal file
333
tests/agent/test_context_builder.py
Normal file
@ -0,0 +1,333 @@
|
||||
"""Tests for ContextBuilder — system prompt and message assembly."""
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _builder(tmp_path: Path, **kw) -> ContextBuilder:
    """Construct a ContextBuilder whose workspace is *tmp_path*.

    Extra keyword arguments are forwarded to the ContextBuilder constructor.
    """
    context_builder = ContextBuilder(workspace=tmp_path, **kw)
    return context_builder
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_runtime_context (static)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildRuntimeContext:
    """Checks on the text produced by ContextBuilder._build_runtime_context."""

    def test_time_only(self):
        """Without channel/chat id: markers and time present, no channel line."""
        rendered = ContextBuilder._build_runtime_context(None, None)
        for fragment in ("[Runtime Context", "[/Runtime Context]", "Current Time:"):
            assert fragment in rendered
        assert "Channel:" not in rendered

    def test_with_channel_and_chat_id(self):
        """Channel and chat id each get their own labelled entry."""
        rendered = ContextBuilder._build_runtime_context("telegram", "chat123")
        assert "Channel: telegram" in rendered
        assert "Chat ID: chat123" in rendered

    def test_with_sender_id(self):
        """A sender_id is rendered as a labelled entry."""
        rendered = ContextBuilder._build_runtime_context(
            "cli", "direct", sender_id="user1"
        )
        assert "Sender ID: user1" in rendered

    def test_with_timezone(self):
        """Passing a timezone still yields a current-time entry."""
        rendered = ContextBuilder._build_runtime_context(
            None, None, timezone="Asia/Shanghai"
        )
        assert "Current Time:" in rendered

    def test_no_channel_no_chat_id_omits_both(self):
        """With both None, neither the channel nor the chat id label appears."""
        rendered = ContextBuilder._build_runtime_context(None, None)
        for label in ("Channel:", "Chat ID:"):
            assert label not in rendered

    def test_no_sender_id_omits(self):
        """Without a sender_id, no sender label appears."""
        rendered = ContextBuilder._build_runtime_context("cli", "direct")
        assert "Sender ID:" not in rendered
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _merge_message_content (static)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMergeMessageContent:
    """Checks on ContextBuilder._merge_message_content for each input pairing."""

    @staticmethod
    def _text(value):
        """Return the text-block dict form of *value*."""
        return {"type": "text", "text": value}

    def test_str_plus_str(self):
        """Two strings are joined with a blank line between them."""
        merged = ContextBuilder._merge_message_content("hello", "world")
        assert merged == "hello\n\nworld"

    def test_empty_left_plus_str(self):
        """An empty left string leaves just the right string."""
        merged = ContextBuilder._merge_message_content("", "world")
        assert merged == "world"

    def test_list_plus_list(self):
        """Two block lists are concatenated in order."""
        merged = ContextBuilder._merge_message_content(
            [self._text("a")], [self._text("b")]
        )
        assert len(merged) == 2
        assert merged[0]["text"] == "a"
        assert merged[1]["text"] == "b"

    def test_str_plus_list(self):
        """A left string becomes the first block ahead of the right list."""
        merged = ContextBuilder._merge_message_content("hello", [self._text("b")])
        assert len(merged) == 2
        assert merged[0]["text"] == "hello"
        assert merged[1]["text"] == "b"

    def test_list_plus_str(self):
        """A right string becomes the last block after the left list."""
        merged = ContextBuilder._merge_message_content([self._text("a")], "world")
        assert len(merged) == 2
        assert merged[0]["text"] == "a"
        assert merged[1]["text"] == "world"

    def test_none_plus_str(self):
        """None on the left plus a string yields a single text block."""
        merged = ContextBuilder._merge_message_content(None, "hello")
        assert merged == [self._text("hello")]

    def test_str_plus_none(self):
        """A string plus None on the right yields a single text block."""
        merged = ContextBuilder._merge_message_content("hello", None)
        assert merged == [self._text("hello")]

    def test_none_plus_none(self):
        """Two None inputs yield an empty list."""
        merged = ContextBuilder._merge_message_content(None, None)
        assert merged == []

    def test_list_items_not_dicts_wrapped(self):
        """A non-dict list item is wrapped into a text block."""
        merged = ContextBuilder._merge_message_content(["raw_item"], None)
        assert merged == [self._text("raw_item")]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _load_bootstrap_files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadBootstrapFiles:
    """Checks on ContextBuilder._load_bootstrap_files for workspace contents."""

    @staticmethod
    def _load(tmp_path, files=None):
        """Write *files* (name → text) as UTF-8 into *tmp_path*, then load.

        Returns whatever ``_load_bootstrap_files`` produces for that workspace.
        """
        for name, text in (files or {}).items():
            (tmp_path / name).write_text(text, encoding="utf-8")
        return _builder(tmp_path)._load_bootstrap_files()

    def test_no_bootstrap_files(self, tmp_path):
        """An empty workspace yields an empty string."""
        assert self._load(tmp_path) == ""

    def test_agents_md(self, tmp_path):
        """AGENTS.md shows up under its own heading with its content."""
        loaded = self._load(tmp_path, {"AGENTS.md": "Be helpful."})
        assert "## AGENTS.md" in loaded
        assert "Be helpful." in loaded

    def test_multiple_bootstrap_files(self, tmp_path):
        """Both AGENTS.md and SOUL.md are included with headings and content."""
        loaded = self._load(tmp_path, {"AGENTS.md": "Rules.", "SOUL.md": "Soul."})
        for fragment in ("## AGENTS.md", "## SOUL.md", "Rules.", "Soul."):
            assert fragment in loaded

    def test_all_bootstrap_files(self, tmp_path):
        """Every name in BOOTSTRAP_FILES gets a heading when its file exists."""
        names = ContextBuilder.BOOTSTRAP_FILES
        loaded = self._load(tmp_path, {name: f"Content of {name}" for name in names})
        for name in names:
            assert f"## {name}" in loaded

    def test_utf8_content(self, tmp_path):
        """Non-ASCII UTF-8 content is carried through unchanged."""
        loaded = self._load(tmp_path, {"AGENTS.md": "用中文回复"})
        assert "用中文回复" in loaded
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_template_content (static)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsTemplateContent:
    """Checks for the static ContextBuilder._is_template_content helper."""

    def test_nonexistent_template_returns_false(self):
        """An unknown template path is reported as not matching."""
        verdict = ContextBuilder._is_template_content("anything", "nonexistent/path.md")
        assert verdict is False

    def test_content_matching_template(self):
        """The bundled MEMORY.md text compared against itself is reported as a match."""
        from importlib.resources import files as pkg_files

        templates_root = pkg_files("nanobot") / "templates"
        memory_template = templates_root / "memory" / "MEMORY.md"
        # Skip rather than fail when the packaged template is not available.
        if not memory_template.is_file():
            pytest.skip("MEMORY.md template not bundled")

        bundled_text = memory_template.read_text(encoding="utf-8")
        verdict = ContextBuilder._is_template_content(bundled_text, "memory/MEMORY.md")
        assert verdict is True

    def test_modified_content_returns_false(self):
        """Text that differs from the bundled MEMORY.md is reported as not matching."""
        from importlib.resources import files as pkg_files

        templates_root = pkg_files("nanobot") / "templates"
        memory_template = templates_root / "memory" / "MEMORY.md"
        # Skip rather than fail when the packaged template is not available.
        if not memory_template.is_file():
            pytest.skip("MEMORY.md template not bundled")

        verdict = ContextBuilder._is_template_content("totally different", "memory/MEMORY.md")
        assert verdict is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_user_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildUserContent:
    """Checks for ContextBuilder._build_user_content with and without media paths."""

    def test_no_media_returns_string(self, tmp_path):
        """Passing None for media returns the text unchanged."""
        content = _builder(tmp_path)._build_user_content("hello", None)
        assert content == "hello"

    def test_empty_media_returns_string(self, tmp_path):
        """Passing an empty media list returns the text unchanged."""
        content = _builder(tmp_path)._build_user_content("hello", [])
        assert content == "hello"

    def test_nonexistent_media_file_returns_string(self, tmp_path):
        """A media path that does not exist leaves the result as the plain text."""
        missing_path = "/nonexistent/image.png"
        content = _builder(tmp_path)._build_user_content("hello", [missing_path])
        assert content == "hello"

    def test_non_image_file_returns_string(self, tmp_path):
        """A text file passed as media leaves the result as the plain text."""
        text_file = tmp_path / "doc.txt"
        text_file.write_text("not an image", encoding="utf-8")

        content = _builder(tmp_path)._build_user_content("hello", [str(text_file)])

        assert content == "hello"

    def test_valid_image_returns_list(self, tmp_path):
        """A PNG file yields an image_url block followed by the text block."""
        image_file = tmp_path / "test.png"
        # PNG signature followed by padding bytes.
        image_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 16)

        content = _builder(tmp_path)._build_user_content("hello", [str(image_file)])

        assert isinstance(content, list)
        assert len(content) == 2
        image_block, text_block = content[0], content[1]
        assert image_block["type"] == "image_url"
        assert image_block["image_url"]["url"].startswith("data:image/png;base64,")
        assert text_block["type"] == "text"
        assert text_block["text"] == "hello"

    def test_image_meta_includes_path(self, tmp_path):
        """The image block carries a _meta mapping that has a path entry."""
        image_file = tmp_path / "test.png"
        # PNG signature followed by padding bytes.
        image_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 16)

        content = _builder(tmp_path)._build_user_content("hello", [str(image_file)])

        image_block = content[0]
        assert "_meta" in image_block
        assert "path" in image_block["_meta"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_system_prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSystemPrompt:
    """Checks for ContextBuilder.build_system_prompt."""

    def test_returns_nonempty_string(self, tmp_path):
        """The prompt is a str with at least one character."""
        prompt = _builder(tmp_path).build_system_prompt()
        assert isinstance(prompt, str)
        assert len(prompt) > 0

    def test_includes_identity_section(self, tmp_path):
        """The prompt mentions 'workspace' or 'python', ignoring case."""
        prompt = _builder(tmp_path).build_system_prompt()
        lowered = prompt.lower()
        assert any(keyword in lowered for keyword in ("workspace", "python"))

    def test_includes_bootstrap_files(self, tmp_path):
        """Text from AGENTS.md in the workspace appears in the prompt."""
        agents_file = tmp_path / "AGENTS.md"
        agents_file.write_text("Be helpful and concise.", encoding="utf-8")

        prompt = _builder(tmp_path).build_system_prompt()

        assert "Be helpful and concise." in prompt

    def test_includes_session_summary(self, tmp_path):
        """A session summary is included along with its archive marker."""
        prompt = _builder(tmp_path).build_system_prompt(
            session_summary="Previous chat about Python.",
        )
        assert "Previous chat about Python." in prompt
        assert "[Archived Context Summary]" in prompt

    def test_sections_separated_by_separator(self, tmp_path):
        """With a bootstrap file and a summary, the '---' separator appears in the prompt."""
        agents_file = tmp_path / "AGENTS.md"
        agents_file.write_text("Rules.", encoding="utf-8")

        prompt = _builder(tmp_path).build_system_prompt(session_summary="Summary.")

        assert "\n\n---\n\n" in prompt

    def test_no_bootstrap_no_summary(self, tmp_path):
        """Without bootstrap files or a summary, neither marker appears in the prompt."""
        prompt = _builder(tmp_path).build_system_prompt()
        for absent_marker in ("## AGENTS.md", "[Archived Context Summary]"):
            assert absent_marker not in prompt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildMessages:
    """Checks for ContextBuilder.build_messages."""

    def test_basic_empty_history(self, tmp_path):
        """With no history, the result is a system message followed by the user message."""
        built = _builder(tmp_path).build_messages([], "hello")

        assert len(built) == 2
        system_msg, user_msg = built[0], built[1]
        assert system_msg["role"] == "system"
        assert user_msg["role"] == "user"
        assert "hello" in str(user_msg["content"])

    def test_runtime_context_injected(self, tmp_path):
        """The final message carries a runtime-context marker alongside the user text."""
        built = _builder(tmp_path).build_messages(
            [], "hello", channel="cli", chat_id="direct",
        )
        last_content = str(built[-1]["content"])
        assert "[Runtime Context" in last_content
        assert "hello" in last_content

    def test_consecutive_same_role_merged(self, tmp_path):
        """A trailing user message in history is merged with the new user message."""
        prior = [{"role": "user", "content": "previous user message"}]

        built = _builder(tmp_path).build_messages(prior, "new message")

        # Expect only the system message plus one merged user message.
        assert len(built) == 2
        merged_text = str(built[1]["content"])
        assert "previous user message" in merged_text
        assert "new message" in str(built[1]["content"])

    def test_different_role_appended(self, tmp_path):
        """After an assistant message, the new user message is a separate entry."""
        prior = [{"role": "assistant", "content": "previous response"}]

        built = _builder(tmp_path).build_messages(prior, "new message")

        # Expect system, assistant and user messages.
        assert len(built) == 3

    def test_media_with_history(self, tmp_path):
        """With an image attached, the final message content is a list with an image_url block."""
        image_file = tmp_path / "img.png"
        # PNG signature followed by padding bytes.
        image_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 16)
        prior = [{"role": "assistant", "content": "see this"}]

        built = _builder(tmp_path).build_messages(
            prior, "check image", media=[str(image_file)],
        )

        last_content = built[-1]["content"]
        assert isinstance(last_content, list)
        assert any(block.get("type") == "image_url" for block in last_content)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# add_tool_result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAddToolResult:
    """Checks for ContextBuilder.add_tool_result."""

    def test_appends_tool_message(self, tmp_path):
        """A tool message with the given id, name and content is added after the existing one."""
        conversation = [{"role": "user", "content": "hello"}]

        updated = _builder(tmp_path).add_tool_result(
            conversation, "call_123", "read_file", "file content",
        )

        assert len(updated) == 2
        tool_entry = updated[1]
        assert tool_entry["role"] == "tool"
        assert tool_entry["tool_call_id"] == "call_123"
        assert tool_entry["name"] == "read_file"
        assert tool_entry["content"] == "file content"

    def test_returns_same_list(self, tmp_path):
        """The returned object is the very list that was passed in."""
        conversation = []
        updated = _builder(tmp_path).add_tool_result(conversation, "id", "tool", "ok")
        assert updated is conversation
|
||||
301
tests/agent/test_loop_runner_integration.py
Normal file
301
tests/agent/test_loop_runner_integration.py
Normal file
@ -0,0 +1,301 @@
|
||||
"""Tests for AgentLoop integration with AgentRunner: streaming, think-filter, error handling, subagent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
# Tool-result character limit read from a default-constructed AgentDefaults.
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
def _make_loop(tmp_path):
    """Build an AgentLoop over a MagicMock provider with its collaborators patched.

    ContextBuilder, SessionManager and SubagentManager are patched in
    ``nanobot.agent.loop`` while the loop is constructed; the patched
    SubagentManager instance gets an async ``cancel_by_session`` returning 0.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.queue import MessageBus

    message_bus = MessageBus()
    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"

    with patch("nanobot.agent.loop.ContextBuilder"):
        with patch("nanobot.agent.loop.SessionManager"):
            with patch("nanobot.agent.loop.SubagentManager") as subagent_cls:
                subagent_cls.return_value.cancel_by_session = AsyncMock(return_value=0)
                agent_loop = AgentLoop(
                    bus=message_bus,
                    provider=fake_provider,
                    workspace=tmp_path,
                )
    return agent_loop
|
||||
|
||||
@pytest.mark.asyncio
async def test_loop_max_iterations_message_stays_stable(tmp_path):
    """When every reply requests a tool, the loop ends with the fixed max-iterations text."""
    loop = _make_loop(tmp_path)

    endless_tool_reply = LLMResponse(
        content="working",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
    )
    loop.provider.chat_with_retry = AsyncMock(return_value=endless_tool_reply)
    loop.tools.get_definitions = MagicMock(return_value=[])
    loop.tools.execute = AsyncMock(return_value="ok")
    loop.max_iterations = 2

    final_content, _, _, _, _ = await loop._run_agent_loop([])

    expected_message = (
        "I reached the maximum number of tool call iterations (2) "
        "without completing the task. You can try breaking the task into smaller steps."
    )
    assert final_content == expected_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp_path):
    """A stream that opens with a think block surfaces only the text that follows it."""
    loop = _make_loop(tmp_path)
    streamed_chunks: list[str] = []
    end_flags: list[bool] = []

    async def chat_stream_with_retry(*, on_content_delta, **kwargs):
        # The think block is split across the two deltas.
        for piece in ("<think>hidden", "</think>Hello"):
            await on_content_delta(piece)
        return LLMResponse(content="<think>hidden</think>Hello", tool_calls=[], usage={})

    async def on_stream(delta: str) -> None:
        streamed_chunks.append(delta)

    async def on_stream_end(*, resuming: bool = False) -> None:
        end_flags.append(resuming)

    loop.provider.chat_stream_with_retry = chat_stream_with_retry

    final_content, _, _, _, _ = await loop._run_agent_loop(
        [],
        on_stream=on_stream,
        on_stream_end=on_stream_end,
    )

    assert final_content == "Hello"
    assert streamed_chunks == ["Hello"]
    assert end_flags == [False]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_loop_stream_filter_hides_partial_trailing_think_prefix(tmp_path):
    """An opening think tag split across two deltas is not passed to the stream callback."""
    loop = _make_loop(tmp_path)
    streamed_chunks: list[str] = []

    async def chat_stream_with_retry(*, on_content_delta, **kwargs):
        # The first delta ends in the middle of the opening tag.
        for piece in ("Hello <thin", "k>hidden</think>World"):
            await on_content_delta(piece)
        return LLMResponse(content="Hello <think>hidden</think>World", tool_calls=[], usage={})

    async def on_stream(delta: str) -> None:
        streamed_chunks.append(delta)

    loop.provider.chat_stream_with_retry = chat_stream_with_retry

    final_content, _, _, _, _ = await loop._run_agent_loop([], on_stream=on_stream)

    assert final_content == "Hello World"
    assert streamed_chunks == ["Hello", " World"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_loop_stream_filter_hides_complete_trailing_think_tag(tmp_path):
    """A complete opening think tag at the end of a delta is not passed to the stream callback."""
    loop = _make_loop(tmp_path)
    streamed_chunks: list[str] = []

    async def chat_stream_with_retry(*, on_content_delta, **kwargs):
        # The first delta ends with the whole opening tag.
        for piece in ("Hello <think>", "hidden</think>World"):
            await on_content_delta(piece)
        return LLMResponse(content="Hello <think>hidden</think>World", tool_calls=[], usage={})

    async def on_stream(delta: str) -> None:
        streamed_chunks.append(delta)

    loop.provider.chat_stream_with_retry = chat_stream_with_retry

    final_content, _, _, _, _ = await loop._run_agent_loop([], on_stream=on_stream)

    assert final_content == "Hello World"
    assert streamed_chunks == ["Hello", " World"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_loop_retries_think_only_final_response(tmp_path):
    """A reply containing only a think block leads to a second provider call."""
    loop = _make_loop(tmp_path)
    attempts = 0

    async def chat_with_retry(**kwargs):
        nonlocal attempts
        attempts += 1
        # Only the first reply is think-only; later replies carry real text.
        reply_text = "<think>hidden</think>" if attempts == 1 else "Recovered answer"
        return LLMResponse(content=reply_text, tool_calls=[], usage={})

    loop.provider.chat_with_retry = chat_with_retry

    final_content, _, _, _, _ = await loop._run_agent_loop([])

    assert final_content == "Recovered answer"
    assert attempts == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_streamed_flag_not_set_on_llm_error(tmp_path):
    """An LLM error reply on a streaming-capable channel must not be marked as streamed.

    The error text has to remain in the returned message and ``_streamed``
    has to stay unset so the error is still delivered.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    message_bus = MessageBus()
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    loop = AgentLoop(bus=message_bus, provider=provider, workspace=tmp_path, model="test-model")

    failure = LLMResponse(
        content="503 service unavailable",
        finish_reason="error",
        tool_calls=[],
        usage={},
    )
    # Both the plain and the streaming entry points return the same error reply.
    loop.provider.chat_with_retry = AsyncMock(return_value=failure)
    loop.provider.chat_stream_with_retry = AsyncMock(return_value=failure)
    loop.tools.get_definitions = MagicMock(return_value=[])

    inbound = InboundMessage(
        channel="feishu",
        sender_id="u1",
        chat_id="c1",
        content="hi",
    )
    outbound = await loop._process_message(
        inbound,
        on_stream=AsyncMock(),
        on_stream_end=AsyncMock(),
    )

    assert outbound is not None
    assert "503" in outbound.content
    assert not outbound.metadata.get("_streamed"), \
        "_streamed must not be set when stop_reason is error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ssrf_soft_block_can_finalize_after_streamed_tool_call(tmp_path):
    """After a blocked exec tool call, the follow-up answer is returned and marked streamed."""
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    message_bus = MessageBus()
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"

    blocked_call = ToolCallRequest(
        id="call_ssrf",
        name="exec",
        arguments={"command": "curl http://169.254.169.254/latest/meta-data/"},
    )
    tool_turn = LLMResponse(
        content="checking metadata",
        tool_calls=[blocked_call],
        usage={},
    )
    final_turn = LLMResponse(
        content="I cannot access private URLs. Please share the local file.",
        tool_calls=[],
        usage={},
    )
    # First streamed reply requests the tool, second one is the final answer.
    provider.chat_stream_with_retry = AsyncMock(side_effect=[tool_turn, final_turn])

    loop = AgentLoop(bus=message_bus, provider=provider, workspace=tmp_path, model="test-model")
    loop.tools.get_definitions = MagicMock(return_value=[])
    loop.tools.prepare_call = MagicMock(return_value=(None, {}, None))
    loop.tools.execute = AsyncMock(
        return_value="Error: Command blocked by safety guard (internal/private URL detected)",
    )

    inbound = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="hi")
    outbound = await loop._process_message(
        inbound,
        on_stream=AsyncMock(),
        on_stream_end=AsyncMock(),
    )

    assert outbound is not None
    assert outbound.content == "I cannot access private URLs. Please share the local file."
    assert outbound.metadata.get("_streamed") is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_next_turn_after_llm_error_keeps_turn_boundary(tmp_path):
    """A failed turn is stored as user + placeholder assistant, and the next request keeps that order."""
    from nanobot.agent.loop import AgentLoop
    from nanobot.agent.runner import _PERSISTED_MODEL_ERROR_PLACEHOLDER
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    error_reply = LLMResponse(
        content="429 rate limit exceeded", finish_reason="error", tool_calls=[], usage={},
    )
    ok_reply = LLMResponse(content="Recovered answer", tool_calls=[], usage={})
    provider.chat_with_retry = AsyncMock(side_effect=[error_reply, ok_reply])

    loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
    loop.tools.get_definitions = MagicMock(return_value=[])
    loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False)  # type: ignore[method-assign]

    # Turn one: the provider answers with an error.
    first = await loop._process_message(
        InboundMessage(channel="cli", sender_id="user", chat_id="test", content="first question")
    )
    assert first is not None
    assert first.content == "429 rate limit exceeded"

    session = loop.sessions.get_or_create("cli:test")
    # Compare only role and content of each stored message.
    stored_turns = [
        {field: stored[field] for field in stored if field in {"role", "content"}}
        for stored in session.messages
    ]
    assert stored_turns == [
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": _PERSISTED_MODEL_ERROR_PLACEHOLDER},
    ]

    # Turn two: the provider answers normally.
    second = await loop._process_message(
        InboundMessage(channel="cli", sender_id="user", chat_id="test", content="second question")
    )
    assert second is not None
    assert second.content == "Recovered answer"

    second_request = provider.chat_with_retry.await_args_list[1].kwargs["messages"]
    non_system = [entry for entry in second_request if entry.get("role") != "system"]
    expected_turns = (
        ("user", "first question"),
        ("assistant", _PERSISTED_MODEL_ERROR_PLACEHOLDER),
        ("user", "second question"),
    )
    for position, (role, fragment) in enumerate(expected_turns):
        assert non_system[position]["role"] == role
        assert fragment in non_system[position]["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch):
    """A subagent whose provider always requests a tool announces the fallback text with status 'ok'."""
    from nanobot.agent.subagent import SubagentManager, SubagentStatus
    from nanobot.bus.queue import MessageBus

    message_bus = MessageBus()
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    endless_tool_reply = LLMResponse(
        content="working",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
    )
    provider.chat_with_retry = AsyncMock(return_value=endless_tool_reply)

    manager = SubagentManager(
        provider=provider,
        workspace=tmp_path,
        bus=message_bus,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    manager._announce_result = AsyncMock()

    async def fake_execute(self, **kwargs):
        return "tool result"

    # Replace the real directory listing with a canned result.
    monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)

    origin = {"channel": "test", "chat_id": "c1"}
    status = SubagentStatus(
        task_id="sub-1",
        label="label",
        task_description="do task",
        started_at=time.monotonic(),
    )
    await manager._run_subagent("sub-1", "do task", "label", origin, status)

    manager._announce_result.assert_awaited_once()
    announced = manager._announce_result.await_args.args
    assert announced[3] == "Task completed but no final response was generated."
    assert announced[5] == "ok"
|
||||
File diff suppressed because it is too large
Load Diff
481
tests/agent/test_runner_core.py
Normal file
481
tests/agent/test_runner_core.py
Normal file
@ -0,0 +1,481 @@
|
||||
"""Tests for core AgentRunner behavior: message passing, iteration limits,
|
||||
timeouts, empty-response handling, usage accumulation, and config passthrough."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
# Tool-result character limit read from a default-constructed AgentDefaults.
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_preserves_reasoning_fields_and_tool_results():
    """Reasoning fields and the tool result from turn one are present in the second request."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock(spec=LLMProvider)
    second_request: list[dict] = []
    calls_made = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal calls_made
        calls_made += 1
        if calls_made != 1:
            # Keep the messages of the follow-up request for inspection.
            second_request[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        return LLMResponse(
            content="thinking",
            tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
            reasoning_content="hidden reasoning",
            thinking_blocks=[{"type": "thinking", "thinking": "step"}],
            usage={"prompt_tokens": 5, "completion_tokens": 3},
        )

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[
            {"role": "system", "content": "system"},
            {"role": "user", "content": "do task"},
        ],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await runner.run(spec)

    assert result.final_content == "done"
    assert result.tools_used == ["list_dir"]
    assert result.tool_events == [
        {"name": "list_dir", "status": "ok", "detail": "tool result"}
    ]

    tool_calling_assistants = [
        entry
        for entry in second_request
        if entry.get("role") == "assistant" and entry.get("tool_calls")
    ]
    assert len(tool_calling_assistants) == 1
    replayed = tool_calling_assistants[0]
    assert replayed["reasoning_content"] == "hidden reasoning"
    assert replayed["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
    assert any(
        entry.get("role") == "tool" and entry.get("content") == "tool result"
        for entry in second_request
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_returns_max_iterations_fallback():
    """With a provider that always requests a tool, the run stops at max_iterations with the fallback text."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock(spec=LLMProvider)
    endless_tool_reply = LLMResponse(
        content="still working",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
    )
    provider.chat_with_retry = AsyncMock(return_value=endless_tool_reply)
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    expected_message = (
        "I reached the maximum number of tool call iterations (2) "
        "without completing the task. You can try breaking the task into smaller steps."
    )
    assert result.stop_reason == "max_iterations"
    assert result.final_content == expected_message
    # The fallback text is also the last assistant message of the run.
    closing_message = result.messages[-1]
    assert closing_message["role"] == "assistant"
    assert closing_message["content"] == result.final_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_times_out_hung_llm_request():
    """A provider call that sleeps for an hour is cut short and reported as a timed-out error."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock(spec=LLMProvider)

    async def chat_with_retry(**kwargs):
        # Stand-in for a request that never comes back in time.
        await asyncio.sleep(3600)

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        llm_timeout_s=0.05,
    )
    started = time.monotonic()
    result = await runner.run(spec)
    elapsed = time.monotonic() - started

    assert elapsed < 1.0
    assert result.stop_reason == "error"
    final_text = result.final_content or ""
    assert "timed out" in final_text.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_replaces_empty_tool_result_with_marker():
    """An empty tool output is sent to the next request as a 'completed with no output' marker."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock(spec=LLMProvider)
    second_request: list[dict] = []
    calls_made = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal calls_made
        calls_made += 1
        if calls_made != 1:
            # Keep the messages of the follow-up request for inspection.
            second_request[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        return LLMResponse(
            content="working",
            tool_calls=[ToolCallRequest(id="call_1", name="noop", arguments={})],
            usage={},
        )

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="")

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "done"
    tool_message = next(entry for entry in second_request if entry.get("role") == "tool")
    assert tool_message["content"] == "(noop completed with no output)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_retries_empty_final_response_with_summary_prompt():
    """Two empty replies are retried with tools; the third request is made without tools."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock(spec=LLMProvider)
    recorded_calls: list[dict] = []

    async def chat_with_retry(*, messages, tools=None, **kwargs):
        recorded_calls.append({"messages": messages, "tools": tools})
        if len(recorded_calls) > 2:
            return LLMResponse(
                content="final answer",
                tool_calls=[],
                usage={"prompt_tokens": 3, "completion_tokens": 7},
            )
        # The first two replies carry no content and no tool calls.
        return LLMResponse(
            content=None,
            tool_calls=[],
            usage={"prompt_tokens": 5, "completion_tokens": 1},
        )

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "final answer"
    assert len(recorded_calls) == 3
    first_call, second_call, third_call = recorded_calls
    assert first_call["tools"] is not None
    assert second_call["tools"] is not None
    assert third_call["tools"] is None
    # Usage is summed over all three replies: 5 + 5 + 3 and 1 + 1 + 7.
    assert result.usage["prompt_tokens"] == 13
    assert result.usage["completion_tokens"] == 9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_uses_specific_message_after_empty_finalization_retry():
    """If every reply is empty, the run ends with EMPTY_FINAL_RESPONSE_MESSAGE and its stop reason."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE

    provider = MagicMock(spec=LLMProvider)

    async def chat_with_retry(*, messages, **kwargs):
        # Every reply is empty: no content and no tool calls.
        return LLMResponse(content=None, tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == EMPTY_FINAL_RESPONSE_MESSAGE
    assert result.stop_reason == "empty_final_response"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_empty_response_does_not_break_tool_chain():
    """An empty reply between two tool calls must not end the run.

    Scripted replies: tool call, empty reply, tool call, final text.
    The run is expected to finish as completed after four provider calls.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock(spec=LLMProvider)
    progress = {"calls": 0}

    def _read_file_request(call_id, path):
        # A reply with no content that asks for read_file on the given path.
        return LLMResponse(
            content=None,
            tool_calls=[ToolCallRequest(id=call_id, name="read_file", arguments={"path": path})],
            usage={"prompt_tokens": 10, "completion_tokens": 5},
        )

    async def chat_with_retry(*, messages, tools=None, **kwargs):
        progress["calls"] += 1
        step = progress["calls"]
        if step == 1:
            return _read_file_request("tc1", "a.txt")
        if step == 2:
            return LLMResponse(
                content=None,
                tool_calls=[],
                usage={"prompt_tokens": 10, "completion_tokens": 1},
            )
        if step == 3:
            return _read_file_request("tc2", "b.txt")
        return LLMResponse(
            content="Here are the results.",
            tool_calls=[],
            usage={"prompt_tokens": 10, "completion_tokens": 10},
        )

    # The same scripted coroutine serves both provider entry points.
    provider.chat_with_retry = chat_with_retry
    provider.chat_stream_with_retry = chat_with_retry

    async def fake_tool(name, args, **kw):
        return "file content"

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = [
        {"type": "function", "function": {"name": "read_file"}}
    ]
    tool_registry.execute = AsyncMock(side_effect=fake_tool)

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "read both files"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=10,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "Here are the results."
    assert result.stop_reason == "completed"
    assert progress["calls"] == 4
    assert "read_file" in result.tools_used
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_accumulates_usage_and_preserves_cached_tokens():
    """Usage counters are summed over iterations, ``cached_tokens`` included.

    The first provider reply requests a tool call, every later reply is the
    final answer; the reported usage must be the sum of both replies.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    calls = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal calls
        calls += 1
        if calls > 1:
            return LLMResponse(
                content="done",
                tool_calls=[],
                usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
            )
        return LLMResponse(
            content="thinking",
            tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})],
            usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80},
        )

    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="file content")

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    # Each counter is the sum of the two provider replies.
    usage = result.usage
    assert usage["prompt_tokens"] == 300  # 100 + 200
    assert usage["completion_tokens"] == 30  # 10 + 20
    assert usage["cached_tokens"] == 230  # 80 + 150
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_binds_on_retry_wait_to_retry_callback_not_progress():
    """Regression: retry heartbeats go to ``retry_wait_callback`` only.

    An earlier runtime refactor bound the provider's ``on_retry_wait`` hook to
    ``progress_callback``, which leaked internal retry diagnostics such as
    "Model request failed, retry in 1s" to end-user channels as ordinary
    progress updates.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider_kwargs: dict = {}

    async def chat_with_retry(**kwargs):
        # Record every keyword the runner forwards to the provider.
        provider_kwargs.update(kwargs)
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    progress_cb = AsyncMock()
    retry_wait_cb = AsyncMock()

    spec = AgentRunSpec(
        initial_messages=[
            {"role": "system", "content": "system"},
            {"role": "user", "content": "hi"},
        ],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        progress_callback=progress_cb,
        retry_wait_callback=retry_wait_cb,
    )
    await AgentRunner(provider).run(spec)

    bound_hook = provider_kwargs["on_retry_wait"]
    assert bound_hook is retry_wait_cb
    assert bound_hook is not progress_cb
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config passthrough tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_passes_temperature_to_provider():
    """``AgentRunSpec.temperature`` is forwarded to ``provider.chat_with_retry``."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    forwarded: dict = {}

    async def chat_with_retry(**kwargs):
        # Record the keywords the runner hands to the provider.
        forwarded.update(kwargs)
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        temperature=0.7,
    )
    await AgentRunner(provider).run(spec)

    assert forwarded["temperature"] == 0.7
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_passes_max_tokens_to_provider():
    """``AgentRunSpec.max_tokens`` is forwarded to ``provider.chat_with_retry``."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    forwarded: dict = {}

    async def chat_with_retry(**kwargs):
        # Record the keywords the runner hands to the provider.
        forwarded.update(kwargs)
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        max_tokens=8192,
    )
    await AgentRunner(provider).run(spec)

    assert forwarded["max_tokens"] == 8192
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_passes_reasoning_effort_to_provider():
    """``AgentRunSpec.reasoning_effort`` is forwarded to ``provider.chat_with_retry``."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    forwarded: dict = {}

    async def chat_with_retry(**kwargs):
        # Record the keywords the runner hands to the provider.
        forwarded.update(kwargs)
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        reasoning_effort="high",
    )
    await AgentRunner(provider).run(spec)

    assert forwarded["reasoning_effort"] == "high"
|
||||
171
tests/agent/test_runner_errors.py
Normal file
171
tests/agent/test_runner_errors.py
Normal file
@ -0,0 +1,171 @@
|
||||
"""Tests for AgentRunner error handling: tool errors, LLM errors,
|
||||
session message isolation, and tool result preservation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_returns_structured_tool_error():
    """With ``fail_on_tool_error`` set, a raising tool yields a ``tool_error`` result.

    Checks the stop reason, the formatted error string and the recorded
    tool event for a tool whose ``execute`` raises ``RuntimeError("boom")``.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    tool_call_reply = LLMResponse(
        content="working",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
    )
    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = AsyncMock(return_value=tool_call_reply)

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(side_effect=RuntimeError("boom"))

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        fail_on_tool_error=True,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.stop_reason == "tool_error"
    assert result.error == "Error: RuntimeError: boom"
    expected_events = [{"name": "list_dir", "status": "error", "detail": "boom"}]
    assert result.tool_events == expected_events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_llm_error_not_appended_to_session_messages():
    """Provider error text must stay out of the returned message list.

    When the LLM reply has ``finish_reason='error'`` its content is still
    surfaced as ``final_content``, but the last assistant message in
    ``result.messages`` must be the persisted placeholder so session history
    is not polluted.
    """
    from nanobot.agent.runner import (
        _PERSISTED_MODEL_ERROR_PLACEHOLDER,
        AgentRunner,
        AgentRunSpec,
    )

    error_reply = LLMResponse(
        content="429 rate limit exceeded",
        finish_reason="error",
        tool_calls=[],
        usage={},
    )
    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = AsyncMock(return_value=error_reply)

    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.stop_reason == "error"
    assert result.final_content == "429 rate limit exceeded"

    assistant_msgs = [m for m in result.messages if m.get("role") == "assistant"]
    leaked = [m for m in assistant_msgs if "429" in (m.get("content") or "")]
    assert not leaked, "Error content should not appear in session messages"
    assert assistant_msgs[-1]["content"] == _PERSISTED_MODEL_ERROR_PLACEHOLDER
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_tool_error_sets_final_content():
    """A fatal tool error is reported through ``final_content`` as well as ``stop_reason``."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    async def chat_with_retry(*, messages, **kwargs):
        # Always request the same tool call; the tool itself is what fails.
        request = ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})
        return LLMResponse(content="working", tool_calls=[request], usage={})

    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(side_effect=RuntimeError("boom"))

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        fail_on_tool_error=True,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "Error: RuntimeError: boom"
    assert result.stop_reason == "tool_error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_tool_error_preserves_tool_results_in_messages():
    """Tool results are recorded even when one tool fails fatally (#2943).

    The provider requests two tool calls in one turn and the second one
    raises. Both results must still be appended to the messages so the
    session never contains an assistant ``tool_calls`` entry without a
    matching tool result.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    async def chat_with_retry(*, messages, **kwargs):
        requests = [
            ToolCallRequest(id="tc1", name="read_file", arguments={"path": "a"}),
            ToolCallRequest(id="tc2", name="exec", arguments={"cmd": "bad"}),
        ]
        return LLMResponse(content=None, tool_calls=requests, usage={})

    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = chat_with_retry
    provider.chat_stream_with_retry = chat_with_retry

    executions = {"n": 0}

    async def fake_execute(name, args, **kw):
        # Only the second execution fails.
        executions["n"] += 1
        if executions["n"] == 2:
            raise RuntimeError("boom")
        return "file content"

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(side_effect=fake_execute)

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do stuff"}],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        fail_on_tool_error=True,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.stop_reason == "tool_error"

    # Both tool results must be present even though tc2 failed fatally.
    tool_positions = [
        pos for pos, msg in enumerate(result.messages) if msg.get("role") == "tool"
    ]
    tool_msgs = [result.messages[pos] for pos in tool_positions]
    assert len(tool_msgs) == 2
    assert tool_msgs[0]["tool_call_id"] == "tc1"
    assert tool_msgs[1]["tool_call_id"] == "tc2"

    # The assistant message carrying tool_calls must come before every result.
    assistant_pos = next(
        pos
        for pos, msg in enumerate(result.messages)
        if msg.get("role") == "assistant" and msg.get("tool_calls")
    )
    assert all(pos > assistant_pos for pos in tool_positions)
|
||||
643
tests/agent/test_runner_governance.py
Normal file
643
tests/agent/test_runner_governance.py
Normal file
@ -0,0 +1,643 @@
|
||||
"""Tests for AgentRunner context governance: backfill, orphan cleanup, microcompact, snip_history."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
def _make_loop(tmp_path):
    """Build an ``AgentLoop`` over a mocked provider with its collaborators patched.

    ``ContextBuilder``, ``SessionManager`` and ``SubagentManager`` are patched
    in ``nanobot.agent.loop`` while the loop is constructed; the subagent
    manager's ``cancel_by_session`` is an ``AsyncMock`` returning ``0``.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.queue import MessageBus

    mock_provider = MagicMock()
    mock_provider.get_default_model.return_value = "test-model"
    message_bus = MessageBus()

    with patch("nanobot.agent.loop.ContextBuilder"):
        with patch("nanobot.agent.loop.SessionManager"):
            with patch("nanobot.agent.loop.SubagentManager") as subagent_cls:
                subagent_cls.return_value.cancel_by_session = AsyncMock(return_value=0)
                return AgentLoop(bus=message_bus, provider=mock_provider, workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_uses_raw_messages_when_context_governance_fails():
    """A failure inside ``_snip_history`` must not stop the run.

    ``_snip_history`` is replaced with a mock that raises; the provider is
    then expected to receive exactly the initial messages and the run to
    finish with the provider's answer.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    captured_messages: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        # Snapshot what the runner actually sent to the model.
        captured_messages[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    initial_messages = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "hello"},
    ]

    runner = AgentRunner(provider)
    # Force the history-snipping step to blow up.
    runner._snip_history = MagicMock(side_effect=RuntimeError("boom"))  # type: ignore[method-assign]
    result = await runner.run(AgentRunSpec(
        initial_messages=initial_messages,
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))

    assert result.final_content == "done"
    assert captured_messages == initial_messages
|
||||
def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch):
    """After trimming, the history must still start with system followed by user.

    Token estimators are patched so that ``_snip_history`` has to trim; the
    kept slice must not begin with an assistant or tool message.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    tools = MagicMock()
    tools.get_definitions.return_value = []
    runner = AgentRunner(MagicMock())

    assistant_with_call = {
        "role": "assistant",
        "content": "tool call",
        "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "ls", "arguments": "{}"}}],
    }
    messages = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "old user"},
        assistant_with_call,
        {"role": "tool", "tool_call_id": "call_1", "content": "tool output"},
        {"role": "assistant", "content": "after tool"},
    ]
    spec = AgentRunSpec(
        initial_messages=messages,
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        context_window_tokens=2000,
        context_block_limit=100,
    )

    # Per-message token cost keyed by content; unknown content costs 40.
    token_sizes = {
        "old user": 120,
        "tool call": 120,
        "tool output": 40,
        "after tool": 40,
        "system": 0,
    }

    def _fake_prompt_estimate(*_args, **_kwargs):
        return (500, None)

    def _fake_message_estimate(msg):
        return token_sizes.get(str(msg.get("content")), 40)

    monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", _fake_prompt_estimate)
    monkeypatch.setattr("nanobot.agent.runner.estimate_message_tokens", _fake_message_estimate)

    trimmed = runner._snip_history(spec, messages)

    # The user message is recovered so the sequence stays valid for providers
    # that require system → user (e.g. GLM error 1214).
    assert trimmed[0]["role"] == "system"
    non_system = [m for m in trimmed if m["role"] != "system"]
    assert non_system[0]["role"] == "user", f"Expected user after system, got {non_system[0]['role']}"
|
||||
@pytest.mark.asyncio
async def test_backfill_missing_tool_results_inserts_error():
    """Orphaned tool_use (no matching tool_result) should get a synthetic error.

    The assistant turn requests ``call_a`` and ``call_b`` but only ``call_a``
    has a result; backfilling must add one result for ``call_b`` carrying the
    backfill placeholder content and the tool's name.
    """
    from nanobot.agent.runner import AgentRunner, _BACKFILL_CONTENT

    messages = [
        {"role": "user", "content": "hi"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {"id": "call_a", "type": "function", "function": {"name": "exec", "arguments": "{}"}},
                {"id": "call_b", "type": "function", "function": {"name": "read_file", "arguments": "{}"}},
            ],
        },
        {"role": "tool", "tool_call_id": "call_a", "name": "exec", "content": "ok"},
    ]
    result = AgentRunner._backfill_missing_tool_results(messages)
    tool_msgs = [m for m in result if m.get("role") == "tool"]
    assert len(tool_msgs) == 2
    backfilled = [m for m in tool_msgs if m.get("tool_call_id") == "call_b"]
    assert len(backfilled) == 1
    assert backfilled[0]["content"] == _BACKFILL_CONTENT
    assert backfilled[0]["name"] == "read_file"
|
||||
|
||||
|
||||
def test_drop_orphan_tool_results_removes_unmatched_tool_messages():
    """A tool result matching no assistant tool call is dropped; the rest is kept in order."""
    from nanobot.agent.runner import AgentRunner

    def _history(*, with_orphan):
        # Fresh dicts on every call so the expectation never aliases the input.
        entries = [
            {"role": "system", "content": "system"},
            {"role": "user", "content": "old user"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {"id": "call_ok", "type": "function", "function": {"name": "read_file", "arguments": "{}"}},
                ],
            },
            {"role": "tool", "tool_call_id": "call_ok", "name": "read_file", "content": "ok"},
        ]
        if with_orphan:
            entries.append({"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"})
        entries.append({"role": "assistant", "content": "after tool"})
        return entries

    cleaned = AgentRunner._drop_orphan_tool_results(_history(with_orphan=True))

    assert cleaned == _history(with_orphan=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_backfill_noop_when_complete():
    """When every tool call already has a result, the input list is returned as-is."""
    from nanobot.agent.runner import AgentRunner

    exec_call = {"id": "call_x", "type": "function", "function": {"name": "exec", "arguments": "{}"}}
    messages = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "", "tool_calls": [exec_call]},
        {"role": "tool", "tool_call_id": "call_x", "name": "exec", "content": "done"},
        {"role": "assistant", "content": "all good"},
    ]

    outcome = AgentRunner._backfill_missing_tool_results(messages)

    # Identity, not equality: no copy should be made.
    assert outcome is messages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_drops_orphan_tool_results_before_model_request():
    """Orphan tool results are hidden from the model but kept in ``result.messages``."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    sent_to_model: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        # Snapshot the context the provider was actually given.
        sent_to_model[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    history = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "old user"},
        {"role": "tool", "tool_call_id": "call_orphan", "name": "exec", "content": "stale"},
        {"role": "assistant", "content": "after orphan"},
        {"role": "user", "content": "new prompt"},
    ]
    spec = AgentRunSpec(
        initial_messages=history,
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    tool_ids_sent = [m.get("tool_call_id") for m in sent_to_model if m.get("role") == "tool"]
    assert "call_orphan" not in tool_ids_sent
    # The returned history still holds the orphan at its original position.
    assert result.messages[2]["tool_call_id"] == "call_orphan"
    assert result.final_content == "done"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_backfill_repairs_model_context_without_shifting_save_turn_boundary(tmp_path):
    """Backfilling old history must not duplicate old tail messages on persist.

    The stored session has an assistant tool call with no result. The model
    request must contain one synthetic result for it, while the persisted
    session afterwards is the old history plus only the new prompt/answer.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.agent.runner import _BACKFILL_CONTENT
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    def _missing_call():
        # Fresh dict per use so input and expectation never share objects.
        return {
            "id": "call_missing",
            "type": "function",
            "function": {"name": "read_file", "arguments": "{}"},
        }

    reply = LLMResponse(content="new answer", tool_calls=[], usage={})
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    provider.chat_with_retry = AsyncMock(return_value=reply)
    provider.chat_stream_with_retry = AsyncMock(return_value=reply)

    loop = AgentLoop(
        bus=MessageBus(),
        provider=provider,
        workspace=tmp_path,
        model="test-model",
    )
    loop.tools.get_definitions = MagicMock(return_value=[])
    loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False)  # type: ignore[method-assign]

    session = loop.sessions.get_or_create("cli:test")
    session.messages = [
        {"role": "user", "content": "old user", "timestamp": "2026-01-01T00:00:00"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [_missing_call()],
            "timestamp": "2026-01-01T00:00:01",
        },
        {"role": "assistant", "content": "old tail", "timestamp": "2026-01-01T00:00:02"},
    ]
    loop.sessions.save(session)

    inbound = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="new prompt")
    result = await loop._process_message(inbound)

    assert result is not None
    assert result.content == "new answer"

    # Exactly one synthetic tool result was sent to the model.
    request_messages = provider.chat_with_retry.await_args.kwargs["messages"]
    synthetic = [
        msg
        for msg in request_messages
        if msg.get("role") == "tool" and msg.get("tool_call_id") == "call_missing"
    ]
    assert len(synthetic) == 1
    assert synthetic[0]["content"] == _BACKFILL_CONTENT

    compared_keys = {"role", "content", "tool_call_id", "name", "tool_calls"}

    def _comparable(message):
        # Drop timestamps and any other bookkeeping before comparing.
        return {key: value for key, value in message.items() if key in compared_keys}

    session_after = loop.sessions.get_or_create("cli:test")
    persisted = [_comparable(message) for message in session_after.messages]
    assert persisted == [
        {"role": "user", "content": "old user"},
        {"role": "assistant", "content": "", "tool_calls": [_missing_call()]},
        {"role": "assistant", "content": "old tail"},
        {"role": "user", "content": "new prompt"},
        {"role": "assistant", "content": "new answer"},
    ]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_backfill_only_mutates_model_context_not_returned_messages():
    """Backfill repairs what the model sees without rewriting ``result.messages``.

    The provider must receive one synthetic result for the unanswered tool
    call, while the returned messages are the initial history plus the new
    assistant answer, with no synthetic entry.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner, _BACKFILL_CONTENT

    def _missing_call():
        # Fresh dict per use so input and expectation never share objects.
        return {
            "id": "call_missing",
            "type": "function",
            "function": {"name": "read_file", "arguments": "{}"},
        }

    def _history():
        return [
            {"role": "system", "content": "system"},
            {"role": "user", "content": "old user"},
            {"role": "assistant", "content": "", "tool_calls": [_missing_call()]},
            {"role": "assistant", "content": "old tail"},
            {"role": "user", "content": "new prompt"},
        ]

    sent_to_model: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        # Snapshot the context the provider was actually given.
        sent_to_model[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=_history(),
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    synthetic = [
        msg
        for msg in sent_to_model
        if msg.get("role") == "tool" and msg.get("tool_call_id") == "call_missing"
    ]
    assert len(synthetic) == 1
    assert synthetic[0]["content"] == _BACKFILL_CONTENT

    compared_keys = {"role", "content", "tool_call_id", "name", "tool_calls"}
    returned = [
        {key: value for key, value in message.items() if key in compared_keys}
        for message in result.messages
    ]
    assert returned == _history() + [{"role": "assistant", "content": "done"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Microcompact (stale tool result compaction)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_microcompact_replaces_old_tool_results():
    """Long tool results older than the most recent ``_MICROCOMPACT_KEEP_RECENT`` get summarized."""
    from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT

    total = _MICROCOMPACT_KEEP_RECENT + 5
    long_content = "x" * 600

    messages: list[dict] = [{"role": "system", "content": "sys"}]
    for idx in range(total):
        # One assistant tool call followed by its (long) result.
        messages += [
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [{"id": f"c{idx}", "type": "function", "function": {"name": "read_file", "arguments": "{}"}}],
            },
            {
                "role": "tool", "tool_call_id": f"c{idx}", "name": "read_file",
                "content": long_content,
            },
        ]

    compacted_history = AgentRunner._microcompact(messages)

    tool_msgs = [m for m in compacted_history if m.get("role") == "tool"]
    compacted = [m for m in tool_msgs if "omitted from context" in str(m.get("content", ""))]
    preserved = [m for m in tool_msgs if m.get("content") == long_content]
    stale_count = total - _MICROCOMPACT_KEEP_RECENT
    assert len(compacted) == stale_count
    assert len(preserved) == _MICROCOMPACT_KEEP_RECENT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_microcompact_preserves_short_results():
    """Short tool results (< _MICROCOMPACT_MIN_CHARS) are left alone, even when stale."""
    from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT

    total = _MICROCOMPACT_KEEP_RECENT + 5

    messages: list[dict] = []
    for idx in range(total):
        # One assistant tool call followed by its short result.
        messages += [
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [{"id": f"c{idx}", "type": "function", "function": {"name": "exec", "arguments": "{}"}}],
            },
            {
                "role": "tool", "tool_call_id": f"c{idx}", "name": "exec",
                "content": "short",
            },
        ]

    outcome = AgentRunner._microcompact(messages)

    # Nothing qualified for compaction, so the very same list comes back.
    assert outcome is messages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_microcompact_skips_non_compactable_tools():
    """Results of non-compactable tools (e.g. 'message') are never replaced."""
    from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT

    total = _MICROCOMPACT_KEEP_RECENT + 5
    long_content = "y" * 1000

    messages: list[dict] = []
    for idx in range(total):
        # One assistant ``message`` tool call followed by its long result.
        messages += [
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [{"id": f"c{idx}", "type": "function", "function": {"name": "message", "arguments": "{}"}}],
            },
            {
                "role": "tool", "tool_call_id": f"c{idx}", "name": "message",
                "content": long_content,
            },
        ]

    outcome = AgentRunner._microcompact(messages)

    # No compactable tool appears, so the very same list comes back.
    assert outcome is messages
|
||||
|
||||
|
||||
def test_governance_repairs_orphans_after_snip():
    """After _snip_history clips an assistant+tool_calls, the second
    _drop_orphan_tool_results pass must clean up the resulting orphans.

    Only the post-snip state is exercised here: ``_drop_orphan_tool_results``
    is called directly on a hand-built snipped history.
    """
    from nanobot.agent.runner import AgentRunner

    # Simulated result of snipping the history
    #   system, user "old msg", assistant(tool_calls=tc_old), tool(tc_old),
    #   assistant "old answer", user "new msg"
    # down to its tail: the assistant that issued ``tc_old`` is gone but its
    # tool result survives as an orphan.
    snipped = [
        {"role": "system", "content": "system"},
        {"role": "tool", "tool_call_id": "tc_old", "name": "search",
         "content": "old result"},
        {"role": "assistant", "content": "old answer"},
        {"role": "user", "content": "new msg"},
    ]

    cleaned = AgentRunner._drop_orphan_tool_results(snipped)
    # The orphan tool result should be removed.
    assert not any(
        m.get("role") == "tool" and m.get("tool_call_id") == "tc_old"
        for m in cleaned
    )
|
||||
|
||||
|
||||
def test_governance_fallback_still_repairs_orphans():
    """The fallback repair chain must remove an orphaned tool result.

    Runs ``_drop_orphan_tool_results`` followed by
    ``_backfill_missing_tool_results`` on a history whose ``tool`` message
    has no assistant ``tool_calls`` entry referencing it, and checks that no
    message with that ``tool_call_id`` is left afterwards.
    """
    from nanobot.agent.runner import AgentRunner

    # A tool result that no assistant message ever asked for.
    history = [
        {"role": "user", "content": "hello"},
        {"role": "tool", "tool_call_id": "orphan_tc", "name": "read",
         "content": "stale"},
        {"role": "assistant", "content": "hi"},
    ]

    after_drop = AgentRunner._drop_orphan_tool_results(history)
    after_backfill = AgentRunner._backfill_missing_tool_results(after_drop)

    # Orphan tool result should be gone.
    surviving_ids = [entry.get("tool_call_id") for entry in after_backfill]
    assert "orphan_tc" not in surviving_ids
|
||||
def test_snip_history_preserves_user_message_after_truncation(monkeypatch):
    """``_snip_history`` must keep a user message at the front of the window.

    Scenario (taken from the original bug report):

    - A normal exchange happens: user asks, assistant calls a tool, the tool
      answers, assistant replies.
    - An injected user message triggers further tool calls.
    - The history exceeds the budget, so only the most recent
      assistant/tool pairs fit into the kept window.
    - The injected user message sits in the truncated prefix.

    The trimmed result must still start (after system messages) with a
    ``user`` message, because providers such as GLM reject a
    system→assistant sequence (error 1214).
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    def exec_call(call_id):
        # Assistant turn carrying a single ``exec`` tool call and no text.
        return {
            "role": "assistant",
            "content": None,
            "tool_calls": [{"id": call_id, "type": "function", "function": {"name": "exec", "arguments": "{}"}}],
        }

    llm = MagicMock()
    registry = MagicMock()
    registry.get_definitions.return_value = []
    runner = AgentRunner(llm)

    history = [
        {"role": "system", "content": "system"},
        {"role": "assistant", "content": "previous reply"},
        {"role": "user", "content": ".nanobot的同目录"},
        exec_call("tc_1"),
        {"role": "tool", "tool_call_id": "tc_1", "content": "tool output 1"},
        exec_call("tc_2"),
        {"role": "tool", "tool_call_id": "tc_2", "content": "tool output 2"},
    ]

    run_spec = AgentRunSpec(
        initial_messages=history,
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        context_window_tokens=2000,
        context_block_limit=100,
    )

    # Report a prompt estimate above budget so _snip_history activates.
    def over_budget(*_a, **_kw):
        return (500, None)

    monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", over_budget)

    # Per-message sizes keyed by content; anything not listed (including the
    # ``None`` content of tool-call turns) costs 100.  This keeps the kept
    # window small: only the last 2 messages fit the budget.
    cost_by_content = {
        "system": 0,
        "previous reply": 200,
        ".nanobot的同目录": 80,
        "tool output 1": 80,
        "tool output 2": 80,
    }

    def fake_message_tokens(msg):
        return cost_by_content.get(str(msg.get("content")), 100)

    monkeypatch.setattr("nanobot.agent.runner.estimate_message_tokens", fake_message_tokens)

    trimmed = runner._snip_history(run_spec, history)

    # The first non-system message MUST be user (not assistant).
    body = [m for m in trimmed if m.get("role") != "system"]
    assert body, "trimmed should contain at least one non-system message"
    assert body[0]["role"] == "user", (
        f"First non-system message must be 'user', got '{body[0]['role']}'. "
        f"Roles: {[m['role'] for m in trimmed]}"
    )
|
||||
|
||||
|
||||
def test_snip_history_no_user_at_all_falls_back_gracefully(monkeypatch):
    """``_snip_history`` must cope with a history that has no user message.

    With only system/assistant/tool entries there is nothing to recover, so
    the method should neither crash nor drop the system message.  Whatever
    it returns must be repairable by ``LLMProvider._enforce_role_alternation``
    so that the first non-system message is not an assistant turn.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.providers.base import LLMProvider

    llm = MagicMock()
    registry = MagicMock()
    registry.get_definitions.return_value = []
    runner = AgentRunner(llm)

    history = [
        {"role": "system", "content": "system"},
        {"role": "assistant", "content": "reply"},
        {"role": "tool", "tool_call_id": "tc_1", "content": "result"},
        {"role": "assistant", "content": "reply 2"},
        {"role": "tool", "tool_call_id": "tc_2", "content": "result 2"},
    ]

    run_spec = AgentRunSpec(
        initial_messages=history,
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        context_window_tokens=2000,
        context_block_limit=100,
    )

    def over_budget(*_a, **_kw):
        return (500, None)

    def flat_cost(msg):
        return 100

    monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", over_budget)
    monkeypatch.setattr("nanobot.agent.runner.estimate_message_tokens", flat_cost)

    trimmed = runner._snip_history(run_spec, history)

    # Should not crash. The result should still be a valid list.
    assert isinstance(trimmed, list)
    # Must have at least system.
    roles_after_snip = [m.get("role") for m in trimmed]
    assert "system" in roles_after_snip

    # The _enforce_role_alternation safety net must be able to fix whatever
    # _snip_history returns here -- verify it produces a valid sequence.
    fixed = LLMProvider._enforce_role_alternation(trimmed)
    body = [m for m in fixed if m["role"] != "system"]
    if body:
        assert body[0]["role"] in ("user", "tool"), (
            f"Safety net should ensure first non-system is user/tool, got {body[0]['role']}"
        )
|
||||
172
tests/agent/test_runner_hooks.py
Normal file
172
tests/agent/test_runner_hooks.py
Normal file
@ -0,0 +1,172 @@
|
||||
"""Tests for AgentRunner hook lifecycle: ordering, streaming deltas,
|
||||
cached-token propagation, and hook context."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
# Tool-result character cap taken from a default-constructed AgentDefaults,
# passed to every AgentRunSpec built in this module.
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_calls_hooks_in_order():
    """Hooks fire in a fixed order across a tool iteration and a final one.

    The provider first asks for a ``list_dir`` tool call and then answers
    ``"done"``.  A recording hook captures every callback; the test pins the
    exact sequence and payloads, including that ``finalize_content`` may
    rewrite the final text (upper-cased here).
    """
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    recorded: list[tuple] = []
    provider_calls = 0

    async def fake_chat(**kwargs):
        # First call requests a tool; every later call finishes the turn.
        nonlocal provider_calls
        provider_calls += 1
        if provider_calls != 1:
            return LLMResponse(content="done", tool_calls=[], usage={})
        return LLMResponse(
            content="thinking",
            tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
        )

    class RecordingHook(AgentHook):
        """Appends one tuple per hook callback to ``recorded``."""

        async def before_iteration(self, context: AgentHookContext) -> None:
            recorded.append(("before_iteration", context.iteration))

        async def before_execute_tools(self, context: AgentHookContext) -> None:
            tool_names = [tc.name for tc in context.tool_calls]
            recorded.append(("before_execute_tools", context.iteration, tool_names))

        async def after_iteration(self, context: AgentHookContext) -> None:
            snapshot = (
                "after_iteration",
                context.iteration,
                context.final_content,
                list(context.tool_results),
                list(context.tool_events),
                context.stop_reason,
            )
            recorded.append(snapshot)

        def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
            recorded.append(("finalize_content", context.iteration, content))
            if content:
                return content.upper()
            return content

    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = fake_chat
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=RecordingHook(),
    )
    result = await AgentRunner(provider).run(spec)

    expected_events = [
        ("before_iteration", 0),
        ("before_execute_tools", 0, ["list_dir"]),
        (
            "after_iteration",
            0,
            None,
            ["tool result"],
            [{"name": "list_dir", "status": "ok", "detail": "tool result"}],
            None,
        ),
        ("before_iteration", 1),
        ("finalize_content", 1, "done"),
        ("after_iteration", 1, "DONE", [], [], "completed"),
    ]
    assert result.final_content == "DONE"
    assert recorded == expected_events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_streaming_hook_receives_deltas_and_end_signal():
    """A hook that wants streaming gets each delta and one end-of-stream call.

    The streaming provider entry point emits ``"he"`` and ``"llo"``.  The hook
    must see both deltas in order, ``on_stream_end`` must be called once with
    ``resuming=False``, and the non-streaming ``chat_with_retry`` must never
    be awaited.
    """
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    deltas_seen: list[str] = []
    end_flags: list[bool] = []

    async def fake_stream(*, on_content_delta, **kwargs):
        for piece in ("he", "llo"):
            await on_content_delta(piece)
        return LLMResponse(content="hello", tool_calls=[], usage={})

    class StreamingHook(AgentHook):
        """Opts into streaming and records what it is given."""

        def wants_streaming(self) -> bool:
            return True

        async def on_stream(self, context: AgentHookContext, delta: str) -> None:
            deltas_seen.append(delta)

        async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
            end_flags.append(resuming)

    provider = MagicMock(spec=LLMProvider)
    provider.chat_stream_with_retry = fake_stream
    provider.chat_with_retry = AsyncMock()
    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=StreamingHook(),
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "hello"
    assert deltas_seen == ["he", "llo"]
    assert end_flags == [False]
    provider.chat_with_retry.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_passes_cached_tokens_to_hook_context():
    """``context.usage`` seen by ``after_iteration`` includes ``cached_tokens``.

    The provider reports ``cached_tokens=150`` in its usage dict; the hook
    copies ``context.usage`` once per iteration and the single captured copy
    must carry that value.
    """
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    usage_snapshots: list[dict] = []

    class UsageHook(AgentHook):
        """Stores a copy of the usage dict after every iteration."""

        async def after_iteration(self, context: AgentHookContext) -> None:
            usage_snapshots.append(dict(context.usage))

    async def fake_chat(**kwargs):
        reported_usage = {"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}
        return LLMResponse(content="done", tool_calls=[], usage=reported_usage)

    provider = MagicMock(spec=LLMProvider)
    provider.chat_with_retry = fake_chat
    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=UsageHook(),
    )
    await AgentRunner(provider).run(spec)

    assert len(usage_snapshots) == 1
    (only_snapshot,) = usage_snapshots
    assert only_snapshot["cached_tokens"] == 150
|
||||
1038
tests/agent/test_runner_injections.py
Normal file
1038
tests/agent/test_runner_injections.py
Normal file
File diff suppressed because it is too large
Load Diff
161
tests/agent/test_runner_persistence.py
Normal file
161
tests/agent/test_runner_persistence.py
Normal file
@ -0,0 +1,161 @@
|
||||
"""Tests for tool result persistence: large results, pruning, temp files, cleanup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
# Tool-result character cap taken from a default-constructed AgentDefaults;
# used by tests here that do not set their own explicit limit.
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
# Explicit marker, matching the other async runner tests, so the coroutine is
# actually awaited even when pytest-asyncio is not running in auto mode.
@pytest.mark.asyncio
async def test_runner_persists_large_tool_results_for_follow_up_calls(tmp_path):
    """An oversized tool result is replaced by a persisted-file reference.

    The tool returns 20,000 characters while the spec caps tool results at
    2,048.  The messages sent on the second provider call must contain a
    ``tool`` message with the ``[tool output persisted]`` marker that mentions
    ``tool-results``, and the output file must exist at
    ``<tmp_path>/.nanobot/tool-results/test_runner/call_big.txt``.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    captured_second_call: list[dict] = []
    call_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        if call_count["n"] == 1:
            # First turn: request the tool whose output will be oversized.
            return LLMResponse(
                content="working",
                tool_calls=[ToolCallRequest(id="call_big", name="list_dir", arguments={"path": "."})],
                usage={"prompt_tokens": 5, "completion_tokens": 3},
            )
        # Later turn: remember what the runner sent, then finish.
        captured_second_call[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="x" * 20_000)

    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        workspace=tmp_path,
        session_key="test:runner",
        max_tool_result_chars=2048,
    ))

    assert result.final_content == "done"
    tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
    assert "[tool output persisted]" in tool_message["content"]
    assert "tool-results" in tool_message["content"]
    # The session key "test:runner" shows up as the "test_runner" bucket on disk.
    assert (tmp_path / ".nanobot" / "tool-results" / "test_runner" / "call_big.txt").exists()
|
||||
|
||||
|
||||
def test_persist_tool_result_prunes_old_session_buckets(tmp_path):
    """Persisting a result removes stale session buckets but keeps fresh ones.

    Two buckets are created up front; one (and the file in it) is back-dated
    by eight days.  After a new oversized result is persisted for
    ``current:session`` the back-dated bucket must be gone, the recent one
    must remain, and the new result must be on disk.
    """
    from nanobot.utils.helpers import maybe_persist_tool_result

    results_root = tmp_path / ".nanobot" / "tool-results"
    stale_dir = results_root / "old_session"
    fresh_dir = results_root / "recent_session"
    for bucket in (stale_dir, fresh_dir):
        bucket.mkdir(parents=True)
    stale_file = stale_dir / "old.txt"
    stale_file.write_text("old", encoding="utf-8")
    (fresh_dir / "recent.txt").write_text("recent", encoding="utf-8")

    # Back-date the old bucket and its file by eight days.
    eight_days_ago = time.time() - (8 * 24 * 60 * 60)
    for target in (stale_dir, stale_file):
        os.utime(target, (eight_days_ago, eight_days_ago))

    persisted = maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )

    assert "[tool output persisted]" in persisted
    assert not stale_dir.exists()
    assert fresh_dir.exists()
    assert (results_root / "current_session" / "call_big.txt").exists()
|
||||
|
||||
|
||||
def test_persist_tool_result_leaves_no_temp_files(tmp_path):
    """After persisting, the session bucket holds the result and no ``*.tmp``."""
    from nanobot.utils.helpers import maybe_persist_tool_result

    oversized_output = "x" * 5000
    maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        oversized_output,
        max_chars=64,
    )

    session_bucket = tmp_path / ".nanobot" / "tool-results" / "current_session"
    assert (session_bucket / "call_big.txt").exists()
    leftover_temp_files = list(session_bucket.glob("*.tmp"))
    assert leftover_temp_files == []
|
||||
|
||||
|
||||
def test_persist_tool_result_logs_cleanup_failures(monkeypatch, tmp_path):
    """A failing bucket cleanup is logged and does not block persistence.

    ``_cleanup_tool_result_buckets`` is replaced with a function that raises
    ``OSError`` and ``logger.exception`` with a recorder.  The call must still
    return the persisted marker and the first recorded log line must mention
    the failed cleanup.
    """
    from nanobot.utils.helpers import maybe_persist_tool_result

    logged: list[str] = []

    def failing_cleanup(*_args, **_kwargs):
        raise OSError("busy")

    def record_exception(message, *args):
        # Formats with str.format, as the patched logger call is given
        # brace-style arguments.
        logged.append(message.format(*args))

    monkeypatch.setattr("nanobot.utils.helpers._cleanup_tool_result_buckets", failing_cleanup)
    monkeypatch.setattr("nanobot.utils.helpers.logger.exception", record_exception)

    persisted = maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )

    assert "[tool output persisted]" in persisted
    assert logged and "Failed to clean stale tool result buckets" in logged[0]
|
||||
# Explicit marker, matching the other async runner tests, so the coroutine is
# actually awaited even when pytest-asyncio is not running in auto mode.
@pytest.mark.asyncio
async def test_runner_keeps_going_when_tool_result_persistence_fails():
    """A persistence failure must not abort the run or alter the tool result.

    ``maybe_persist_tool_result`` (as imported by the runner module) is
    patched to raise ``RuntimeError``.  The run must still finish with
    ``"done"`` and the ``tool`` message sent on the second provider call must
    carry the raw tool output unchanged.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    captured_second_call: list[dict] = []
    call_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        if call_count["n"] == 1:
            # First turn: request a tool so there is a result to persist.
            return LLMResponse(
                content="working",
                tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
                usage={"prompt_tokens": 5, "completion_tokens": 3},
            )
        # Later turn: remember what the runner sent, then finish.
        captured_second_call[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")

    runner = AgentRunner(provider)
    with patch("nanobot.agent.runner.maybe_persist_tool_result", side_effect=RuntimeError("disk full")):
        result = await runner.run(AgentRunSpec(
            initial_messages=[{"role": "user", "content": "do task"}],
            tools=tools,
            model="test-model",
            max_iterations=2,
            max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        ))

    assert result.final_content == "done"
    tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
    assert tool_message["content"] == "tool result"
|
||||
244
tests/agent/test_runner_safety.py
Normal file
244
tests/agent/test_runner_safety.py
Normal file
@ -0,0 +1,244 @@
|
||||
"""Tests for AgentRunner security: workspace violations, SSRF, shell guard, throttling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
# Tool-result character cap taken from a default-constructed AgentDefaults,
# passed to every AgentRunSpec built in this module.
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
# Explicit marker, matching the other async tests in this module, so the
# coroutine is actually awaited even when pytest-asyncio is not in auto mode.
@pytest.mark.asyncio
async def test_runner_does_not_abort_on_workspace_violation_anymore():
    """v2 behavior: workspace-bound rejections are *soft* tool errors.

    Previously (PR #3493) any workspace boundary error became a fatal
    RuntimeError that aborted the turn. That silently killed legitimate
    workspace commands once the heuristic guard misfired (#3599 #3605), so
    we now hand the error back to the LLM as a recoverable tool result and
    rely on ``repeated_workspace_violation_error`` to throttle bypass loops.

    Here the tool raises ``PermissionError`` for a path outside the allowed
    directory; the provider must still be called a second time and the run
    must end with the provider's final text and no error.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    provider.chat_with_retry = AsyncMock(side_effect=[
        LLMResponse(
            content="trying outside",
            tool_calls=[ToolCallRequest(
                id="call_1", name="read_file", arguments={"path": "/tmp/outside.md"},
            )],
        ),
        LLMResponse(content="ok, telling the user instead", tool_calls=[]),
    ])
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(
        side_effect=PermissionError(
            "Path /tmp/outside.md is outside allowed directory /workspace"
        )
    )

    runner = AgentRunner(provider)

    result = await runner.run(AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))

    assert provider.chat_with_retry.await_count == 2, (
        "workspace violation must NOT short-circuit the loop"
    )
    assert result.stop_reason != "tool_error"
    assert result.error is None
    assert result.final_content == "ok, telling the user instead"
    assert result.tool_events and result.tool_events[0]["status"] == "error"
    # Detail still carries the workspace_violation breadcrumb for telemetry,
    # but the runner did not raise.
    assert "workspace_violation" in result.tool_events[0]["detail"]
|
||||
|
||||
|
||||
def test_is_ssrf_violation_recognizes_private_url_blocks():
    """``_is_ssrf_violation`` flags private-URL blocks and nothing else.

    Two private/internal URL rejection messages must classify as SSRF;
    workspace-boundary and deny-pattern messages must not.
    """
    from nanobot.agent.runner import AgentRunner

    cases = [
        # SSRF rejections.
        ("Error: Command blocked by safety guard (internal/private URL detected)", True),
        ("URL validation failed: Blocked: host resolves to private/internal address 192.168.1.2", True),
        # Workspace-bound markers are NOT classified as SSRF.
        ("Error: Command blocked by safety guard (path outside working dir)", False),
        ("Path /tmp/x is outside allowed directory /ws", False),
        # Deny / allowlist filter messages stay non-fatal too.
        ("Error: Command blocked by deny pattern filter", False),
    ]

    for message, expected in cases:
        assert AgentRunner._is_ssrf_violation(message) is expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_returns_non_retryable_hint_on_ssrf_violation():
    """An SSRF block is reported to the LLM with a do-not-retry hint.

    The tool returns the private-URL safety-guard error.  The runner must
    still give the provider a second call, finish with ``"completed"``, tag
    the tool event with an ``ssrf_violation:`` detail, and put the
    non-bypassable / do-not-retry / whitelist guidance into the tool message.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    ssrf_attempt = LLMResponse(
        content="curl-ing metadata",
        tool_calls=[ToolCallRequest(
            id="call_ssrf",
            name="exec",
            arguments={"command": "curl http://169.254.169.254"},
        )],
    )
    graceful_reply = LLMResponse(
        content="I cannot access that private URL. Please share local files.",
        tool_calls=[],
    )

    provider = MagicMock()
    provider.chat_with_retry = AsyncMock(side_effect=[ssrf_attempt, graceful_reply])
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value=(
        "Error: Command blocked by safety guard (internal/private URL detected)"
    ))

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    assert provider.chat_with_retry.await_count == 2
    assert result.stop_reason == "completed"
    assert result.error is None
    assert result.final_content == "I cannot access that private URL. Please share local files."
    assert result.tool_events and result.tool_events[0]["detail"].startswith("ssrf_violation:")

    tool_messages = [m for m in result.messages if m.get("role") == "tool"]
    assert tool_messages
    first_tool_text = tool_messages[0]["content"]
    assert "non-bypassable security boundary" in first_tool_text
    assert "Do not retry" in first_tool_text
    assert "tools.ssrfWhitelist" in first_tool_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_lets_llm_recover_from_shell_guard_path_outside():
    """Reporter scenario for #3599 / #3605: guard hit, agent recovers.

    The shell ``_guard_command`` heuristic fires on ``2>/dev/null``-style
    redirects and other shell idioms.  Before v2 that aborted the whole turn
    (silent hang on Telegram per #3605); now the soft error goes back to the
    LLM, which can finalize on the next iteration.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    second_call_messages: list[dict] = []

    async def scripted_chat(*, messages, **kwargs):
        # The mock's await counter already includes the call in progress.
        first_call = provider.chat_with_retry.await_count == 1
        if not first_call:
            second_call_messages[:] = list(messages)
            return LLMResponse(content="recovered final answer", tool_calls=[])
        blocked_call = ToolCallRequest(
            id="call_blocked",
            name="exec",
            arguments={"command": "rm scratch.txt 2>/dev/null"},
        )
        return LLMResponse(content="trying noisy cleanup", tool_calls=[blocked_call])

    provider.chat_with_retry = AsyncMock(side_effect=scripted_chat)
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(
        return_value="Error: Command blocked by safety guard (path outside working dir)"
    )

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    assert provider.chat_with_retry.await_count == 2, (
        "guard hit must NOT short-circuit the loop -- LLM should get a second turn"
    )
    assert result.stop_reason != "tool_error"
    assert result.error is None
    assert result.final_content == "recovered final answer"
    assert result.tool_events and result.tool_events[0]["status"] == "error"
    # v2: detail keeps the breadcrumb but the runner did not raise.
    assert "workspace_violation" in result.tool_events[0]["detail"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_throttles_repeated_workspace_bypass_attempts():
    """#3493 motivation: stop the LLM bypass loop without aborting the turn.

    The LLM keeps hammering the same outside path (four ``exec cat``
    variants here).  Once the soft retry budget is used up the runner should
    swap the tool result for a hard "stop trying" message, visible as a
    ``workspace_violation_escalated:`` tool event, while the run still ends
    normally when the LLM gives up.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    attempts = 4
    scripted: list[LLMResponse] = []
    for i in range(attempts):
        bypass_call = ToolCallRequest(
            id=f"a{i}", name="exec",
            arguments={"command": f"cat /Users/x/Downloads/01.md # try {i}"},
        )
        scripted.append(LLMResponse(content=f"try {i}", tool_calls=[bypass_call]))
    scripted.append(LLMResponse(content="ok telling user", tool_calls=[]))

    provider = MagicMock()
    provider.chat_with_retry = AsyncMock(side_effect=scripted)
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(
        return_value="Error: Command blocked by safety guard (path outside working dir)"
    )

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=10,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    # All 4 bypass attempts surface to the LLM (no fatal abort), and the
    # runner finally completes once the LLM stops asking.
    assert result.stop_reason != "tool_error"
    assert result.error is None
    assert result.final_content == "ok telling user"

    # The third+ attempts must have been escalated -- look at the events.
    escalated = []
    for ev in result.tool_events:
        if ev["status"] == "error" and ev["detail"].startswith("workspace_violation_escalated:"):
            escalated.append(ev)
    assert escalated, (
        "expected at least one escalated workspace_violation event, got: "
        f"{result.tool_events}"
    )
|
||||
181
tests/agent/test_runner_tool_execution.py
Normal file
181
tests/agent/test_runner_tool_execution.py
Normal file
@ -0,0 +1,181 @@
|
||||
"""Tests for AgentRunner tool execution: batching, concurrency, exclusive tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
# Tool-result character cap taken from a default-constructed AgentDefaults,
# passed to every AgentRunSpec built in this module.
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
class _DelayTool(Tool):
    """Test tool that sleeps for a fixed delay and logs start/end markers.

    Each ``execute`` call appends ``"start:<name>"`` to a shared list, sleeps
    for ``delay`` seconds, appends ``"end:<name>"`` and returns the tool name.
    Tests inspect the shared list to see how calls were ordered or overlapped.
    """

    def __init__(
        self,
        name: str,
        *,
        delay: float,
        read_only: bool,
        shared_events: list[str],
        exclusive: bool = False,
    ):
        """Store the tool's identity, timing and scheduling flags.

        Args:
            name: Value reported by ``name`` and ``description`` and returned
                by ``execute``.
            delay: Seconds to sleep inside ``execute``.
            read_only: Value reported by the ``read_only`` property.
            shared_events: List the start/end markers are appended to; the
                same list object is meant to be shared between tools.
            exclusive: Value reported by the ``exclusive`` property.
        """
        self._name = name
        self._delay = delay
        self._read_only = read_only
        self._shared_events = shared_events
        self._exclusive = exclusive

    @property
    def name(self) -> str:
        """Tool name as given to the constructor."""
        return self._name

    @property
    def description(self) -> str:
        """Description; simply reuses the tool name."""
        return self._name

    @property
    def parameters(self) -> dict:
        """JSON-schema style parameter spec: an object with no properties."""
        return {"type": "object", "properties": {}, "required": []}

    @property
    def read_only(self) -> bool:
        """Read-only flag as given to the constructor."""
        return self._read_only

    @property
    def exclusive(self) -> bool:
        """Exclusive flag as given to the constructor."""
        return self._exclusive

    async def execute(self, **kwargs):
        """Record start, sleep for the configured delay, record end.

        Keyword arguments are accepted and ignored.  Returns the tool name.
        """
        self._shared_events.append(f"start:{self._name}")
        await asyncio.sleep(self._delay)
        self._shared_events.append(f"end:{self._name}")
        return self._name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_batches_read_only_tools_before_exclusive_work():
    """Two read-only tools start together; the write tool runs after both.

    With ``concurrent_tools=True`` the two read-only calls must both start
    before either ends, both must end before the non-read-only call starts,
    and that call's start/end must be the last two recorded events.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    timeline: list[str] = []
    registry = ToolRegistry()
    for tool_name, delay, is_read_only in (
        ("read_a", 0.05, True),
        ("read_b", 0.05, True),
        ("write_a", 0.01, False),
    ):
        registry.register(
            _DelayTool(tool_name, delay=delay, read_only=is_read_only, shared_events=timeline)
        )

    spec = AgentRunSpec(
        initial_messages=[],
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        concurrent_tools=True,
    )
    requested_calls = [
        ToolCallRequest(id="ro1", name="read_a", arguments={}),
        ToolCallRequest(id="ro2", name="read_b", arguments={}),
        ToolCallRequest(id="rw1", name="write_a", arguments={}),
    ]

    runner = AgentRunner(MagicMock())
    await runner._execute_tools(spec, requested_calls, {}, {})

    assert timeline[:2] == ["start:read_a", "start:read_b"]
    assert "end:read_a" in timeline and "end:read_b" in timeline
    write_started_at = timeline.index("start:write_a")
    assert timeline.index("end:read_a") < write_started_at
    assert timeline.index("end:read_b") < write_started_at
    assert timeline[-2:] == ["start:write_a", "end:write_a"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_does_not_batch_exclusive_read_only_tools():
    """An exclusive tool is not overlapped with neighbouring read-only tools.

    Calls are issued as ``read_a``, ``ddg_like`` (read-only but exclusive),
    ``read_b``.  ``read_a`` must finish before ``ddg_like`` starts, and
    ``ddg_like`` must finish before ``read_b`` starts.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    timeline: list[str] = []
    first_reader = _DelayTool("read_a", delay=0.03, read_only=True, shared_events=timeline)
    second_reader = _DelayTool("read_b", delay=0.03, read_only=True, shared_events=timeline)
    exclusive_reader = _DelayTool(
        "ddg_like",
        delay=0.01,
        read_only=True,
        shared_events=timeline,
        exclusive=True,
    )

    registry = ToolRegistry()
    for tool in (first_reader, exclusive_reader, second_reader):
        registry.register(tool)

    spec = AgentRunSpec(
        initial_messages=[],
        tools=registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        concurrent_tools=True,
    )
    requested_calls = [
        ToolCallRequest(id="ro1", name="read_a", arguments={}),
        ToolCallRequest(id="ddg1", name="ddg_like", arguments={}),
        ToolCallRequest(id="ro2", name="read_b", arguments={}),
    ]

    runner = AgentRunner(MagicMock())
    await runner._execute_tools(spec, requested_calls, {}, {})

    assert timeline[0] == "start:read_a"
    assert timeline.index("end:read_a") < timeline.index("start:ddg_like")
    assert timeline.index("end:ddg_like") < timeline.index("start:read_b")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_blocks_repeated_external_fetches():
    """The third identical ``web_fetch`` request is answered with a block notice, not executed."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    captured_final_call: list[dict] = []
    call_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        # First three turns request the same fetch; the fourth turn ends the
        # run and snapshots the messages the runner sent.
        call_count["n"] += 1
        if call_count["n"] > 3:
            captured_final_call[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        fetch_request = ToolCallRequest(
            id=f"call_{call_count['n']}",
            name="web_fetch",
            arguments={"url": "https://example.com"},
        )
        return LLMResponse(content="working", tool_calls=[fetch_request], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="page content")

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "research task"}],
        tools=tools,
        model="test-model",
        max_iterations=4,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "done"
    # Only the first two of the three requested fetches reached the tool.
    assert tools.execute.await_count == 2
    third_call_results = [
        msg
        for msg in captured_final_call
        if msg.get("role") == "tool" and msg.get("tool_call_id") == "call_3"
    ]
    blocked_tool_message = third_call_results[0]
    assert "repeated external lookup blocked" in blocked_tool_message["content"]
|
||||
@ -10,6 +10,7 @@ See: https://github.com/HKUDS/nanobot/issues/2966
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
@ -17,42 +18,47 @@ from unittest.mock import MagicMock, patch, AsyncMock
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
@pytest.fixture
def mock_loop():
    """Provide an AgentLoop built with ``__init__`` bypassed and mocked collaborators."""
    # Only construction needs the no-op __init__; attributes are set afterwards.
    with patch.object(AgentLoop, "__init__", lambda self: None):
        loop = AgentLoop()

    bus = MagicMock()
    bus.publish_outbound = AsyncMock()
    bus.publish_inbound = AsyncMock()

    commands = MagicMock()
    commands.dispatch_priority = AsyncMock(return_value=None)

    attributes = {
        "sessions": MagicMock(),
        "_pending_queues": {},
        "_session_locks": {},
        "_active_tasks": {},
        "_concurrency_gate": None,
        "_RUNTIME_CHECKPOINT_KEY": "runtime_checkpoint",
        "_PENDING_USER_TURN_KEY": "pending_user_turn",
        "bus": bus,
        "commands": commands,
    }
    for attr_name, attr_value in attributes.items():
        setattr(loop, attr_name, attr_value)
    return loop
|
||||
def _make_provider():
|
||||
"""Create an LLM provider mock with required attributes."""
|
||||
from types import SimpleNamespace
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation = SimpleNamespace(max_tokens=4096, temperature=0.1, reasoning_effort=None)
|
||||
provider.estimate_prompt_tokens.return_value = (10_000, "test")
|
||||
return provider
|
||||
|
||||
|
||||
def _make_loop(tmp_path: Path) -> AgentLoop:
    """Build a real AgentLoop around a mock provider, without patching ``__init__``.

    ``ContextBuilder``, ``SessionManager`` and ``SubagentManager`` are patched
    in ``nanobot.agent.loop`` while the loop is constructed.
    """
    bus = MessageBus()
    provider = _make_provider()
    with patch("nanobot.agent.loop.ContextBuilder"):
        with patch("nanobot.agent.loop.SessionManager"):
            with patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
                # cancel_by_session is awaited, so it must be an AsyncMock.
                MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
                return AgentLoop(bus=bus, provider=provider, workspace=tmp_path)
|
||||
|
||||
|
||||
class TestStopPreservesContext:
|
||||
"""Verify that /stop restores partial context via checkpoint."""
|
||||
|
||||
def test_restore_checkpoint_method_exists(self, mock_loop):
|
||||
def test_restore_checkpoint_method_exists(self, tmp_path):
|
||||
"""AgentLoop should have _restore_runtime_checkpoint."""
|
||||
assert hasattr(mock_loop, "_restore_runtime_checkpoint")
|
||||
loop = _make_loop(tmp_path)
|
||||
assert hasattr(loop, "_restore_runtime_checkpoint")
|
||||
|
||||
def test_checkpoint_key_constant(self, mock_loop):
|
||||
def test_checkpoint_key_constant(self, tmp_path):
|
||||
"""The runtime checkpoint key should be defined."""
|
||||
assert mock_loop._RUNTIME_CHECKPOINT_KEY == "runtime_checkpoint"
|
||||
loop = _make_loop(tmp_path)
|
||||
assert loop._RUNTIME_CHECKPOINT_KEY == "runtime_checkpoint"
|
||||
|
||||
def test_cancel_dispatch_restores_checkpoint(self, mock_loop):
|
||||
def test_cancel_dispatch_restores_checkpoint(self, tmp_path):
|
||||
"""When a task is cancelled, the checkpoint should be restored."""
|
||||
# Create a mock session with a checkpoint
|
||||
loop = _make_loop(tmp_path)
|
||||
session = MagicMock()
|
||||
session.metadata = {
|
||||
"runtime_checkpoint": {
|
||||
@ -74,14 +80,11 @@ class TestStopPreservesContext:
|
||||
session.messages = [
|
||||
{"role": "user", "content": "Search for something"},
|
||||
]
|
||||
mock_loop.sessions.get_or_create.return_value = session
|
||||
loop.sessions.get_or_create.return_value = session
|
||||
|
||||
# The restore method should add checkpoint messages to session history
|
||||
restored = mock_loop._restore_runtime_checkpoint(session)
|
||||
restored = loop._restore_runtime_checkpoint(session)
|
||||
assert restored is True
|
||||
# After restore, session should have more messages
|
||||
assert len(session.messages) > 1
|
||||
# The checkpoint should be cleared
|
||||
assert "runtime_checkpoint" not in session.metadata
|
||||
|
||||
|
||||
|
||||
558
tests/agent/test_subagent_lifecycle.py
Normal file
558
tests/agent/test_subagent_lifecycle.py
Normal file
@ -0,0 +1,558 @@
|
||||
"""Tests for SubagentManager lifecycle — spawn, run, announce, cancel."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.hook import AgentHookContext
|
||||
from nanobot.agent.runner import AgentRunResult
|
||||
from nanobot.agent.subagent import (
|
||||
SubagentManager,
|
||||
SubagentStatus,
|
||||
_SubagentHook,
|
||||
)
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _manager(tmp_path: Path, **kw) -> SubagentManager:
    """Build a SubagentManager on a spec'd provider mock; ``kw`` overrides any default argument."""
    provider = MagicMock(spec=LLMProvider)
    provider.get_default_model.return_value = "test-model"
    base_kwargs = {
        "provider": provider,
        "workspace": tmp_path,
        "bus": MessageBus(),
        "model": "test-model",
        "max_tool_result_chars": 16_000,
    }
    # Later keys win, so caller overrides replace the defaults.
    return SubagentManager(**{**base_kwargs, **kw})
|
||||
|
||||
|
||||
def _make_hook_context(**overrides) -> AgentHookContext:
    """Build an AgentHookContext with baseline field values; ``overrides`` replaces any of them."""
    fields = {
        "iteration": 1,
        "tool_calls": [],
        "tool_events": [],
        "messages": [],
        "usage": {},
        "error": None,
        "stop_reason": "completed",
        "final_content": "ok",
    }
    return AgentHookContext(**{**fields, **overrides})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubagentStatus defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubagentStatus:
    """Default field values of a freshly constructed SubagentStatus."""

    def test_defaults(self):
        """Only the required fields are supplied; the rest must take their defaults."""
        status = SubagentStatus(
            task_id="abc",
            label="test",
            task_description="do stuff",
            started_at=time.monotonic(),
        )
        assert status.phase == "initializing"
        assert status.iteration == 0
        assert status.tool_events == []
        assert status.usage == {}
        assert status.stop_reason is None
        assert status.error is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_provider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetProvider:
    """set_provider() must swap the provider on the manager and on its runner."""

    def test_updates_provider_model_runner(self, tmp_path):
        """Provider, model name, and the runner's provider all follow the update."""
        manager = _manager(tmp_path)
        replacement = MagicMock(spec=LLMProvider)

        manager.set_provider(replacement, "new-model")

        assert manager.provider is replacement
        assert manager.model == "new-model"
        assert manager.runner.provider is replacement
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# spawn
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSpawn:
    """spawn(): return text, bookkeeping while a task runs, cleanup once it ends."""

    @staticmethod
    def _done_result() -> AgentRunResult:
        """Return the successful run result used by every test in this class."""
        return AgentRunResult(final_content="done", messages=[], stop_reason="completed")

    @classmethod
    def _block_runner(cls, sm) -> asyncio.Event:
        """Make ``sm.runner.run`` wait on an event, and return that event.

        The spawned task stays running until the caller sets the event, so
        tests can inspect the manager's bookkeeping mid-run.
        """
        block = asyncio.Event()

        async def _slow_run(spec):
            await block.wait()
            return cls._done_result()

        sm.runner.run = _slow_run
        return block

    @pytest.mark.asyncio
    async def test_returns_string_with_task_id(self, tmp_path):
        """The text returned by spawn() mentions "started" and an "id:"."""
        sm = _manager(tmp_path)
        sm.runner.run = AsyncMock(return_value=self._done_result())
        result = await sm.spawn("do something")
        assert "started" in result
        assert "id:" in result

    @pytest.mark.asyncio
    async def test_creates_task_in_running_tasks(self, tmp_path):
        """A spawned task is tracked while running and dropped when it ends."""
        sm = _manager(tmp_path)
        block = self._block_runner(sm)

        await sm.spawn("task", session_key="s1")
        assert len(sm._running_tasks) == 1

        block.set()
        await asyncio.sleep(0.1)
        assert len(sm._running_tasks) == 0

    @pytest.mark.asyncio
    async def test_creates_status(self, tmp_path):
        """A status entry exists while the task runs and is removed afterwards."""
        sm = _manager(tmp_path)
        block = self._block_runner(sm)

        await sm.spawn("my task")
        # Hold the run open so the status can be seen before cleanup; checking
        # only the final empty state would also pass if none was ever created.
        assert len(sm._task_statuses) == 1

        block.set()
        await asyncio.sleep(0.1)
        # Status cleaned up after task completes
        assert len(sm._task_statuses) == 0

    @pytest.mark.asyncio
    async def test_registers_in_session_tasks(self, tmp_path):
        """A session key groups the task under that session until it ends."""
        sm = _manager(tmp_path)
        block = self._block_runner(sm)

        await sm.spawn("task", session_key="s1")
        assert "s1" in sm._session_tasks
        assert len(sm._session_tasks["s1"]) == 1

        block.set()
        await asyncio.sleep(0.1)
        assert "s1" not in sm._session_tasks

    @pytest.mark.asyncio
    async def test_no_session_key_no_registration(self, tmp_path):
        """Without a session key nothing is added to the session index."""
        sm = _manager(tmp_path)
        block = self._block_runner(sm)

        await sm.spawn("task")
        assert len(sm._session_tasks) == 0

        block.set()
        await asyncio.sleep(0.1)

    @pytest.mark.asyncio
    async def test_label_defaults_to_truncated_task(self, tmp_path):
        """With no label given, a long task text is cut to 30 chars plus "..."."""
        sm = _manager(tmp_path)
        block = self._block_runner(sm)

        long_task = "A" * 50
        await sm.spawn(long_task, session_key="s1")
        status = next(iter(sm._task_statuses.values()))
        assert status.label == long_task[:30] + "..."

        block.set()
        await asyncio.sleep(0.1)

    @pytest.mark.asyncio
    async def test_custom_label(self, tmp_path):
        """An explicit label is stored on the status unchanged."""
        sm = _manager(tmp_path)
        block = self._block_runner(sm)

        await sm.spawn("task", label="Custom Label", session_key="s1")
        status = next(iter(sm._task_statuses.values()))
        assert status.label == "Custom Label"

        block.set()
        await asyncio.sleep(0.1)

    @pytest.mark.asyncio
    async def test_cleanup_callback_removes_all_entries(self, tmp_path):
        """After completion, all three bookkeeping maps are empty."""
        sm = _manager(tmp_path)
        sm.runner.run = AsyncMock(return_value=self._done_result())
        await sm.spawn("task", session_key="s1")
        await asyncio.sleep(0.1)
        assert len(sm._running_tasks) == 0
        assert len(sm._task_statuses) == 0
        assert len(sm._session_tasks) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_subagent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunSubagent:
    """_run_subagent(): outcome passed to _announce_result and status updates."""

    @staticmethod
    def _fresh_status() -> SubagentStatus:
        """Return a new status for task ``t1`` as used by every test here."""
        return SubagentStatus(
            task_id="t1",
            label="label",
            task_description="do task",
            started_at=time.monotonic(),
        )

    @staticmethod
    async def _run(sm, status) -> None:
        """Invoke ``_run_subagent`` with the fixed task id, text, label and origin."""
        origin = {"channel": "cli", "chat_id": "direct"}
        await sm._run_subagent("t1", "do task", "label", origin, status)

    @pytest.mark.asyncio
    async def test_successful_run(self, tmp_path):
        """A completed run is announced exactly once with status "ok"."""
        sm = _manager(tmp_path)
        sm.runner.run = AsyncMock(return_value=AgentRunResult(
            final_content="Task done!", messages=[], stop_reason="completed",
        ))
        with patch.object(sm, "_announce_result", new_callable=AsyncMock) as announce:
            await self._run(sm, self._fresh_status())
        announce.assert_called_once()
        assert announce.call_args.args[-2] == "ok"

    @pytest.mark.asyncio
    async def test_tool_error_run(self, tmp_path):
        """A run that stops with "tool_error" is announced with status "error"."""
        sm = _manager(tmp_path)
        sm.runner.run = AsyncMock(return_value=AgentRunResult(
            final_content=None, messages=[], stop_reason="tool_error",
            tool_events=[{"name": "read_file", "status": "error", "detail": "not found"}],
        ))
        with patch.object(sm, "_announce_result", new_callable=AsyncMock) as announce:
            await self._run(sm, self._fresh_status())
        assert announce.call_args.args[-2] == "error"

    @pytest.mark.asyncio
    async def test_exception_run(self, tmp_path):
        """An exception from the runner is recorded on the status and announced as "error"."""
        sm = _manager(tmp_path)
        sm.runner.run = AsyncMock(side_effect=RuntimeError("LLM down"))
        tracked = self._fresh_status()
        with patch.object(sm, "_announce_result", new_callable=AsyncMock) as announce:
            await self._run(sm, tracked)
        assert tracked.phase == "error"
        assert "LLM down" in tracked.error
        assert announce.call_args.args[-2] == "error"

    @pytest.mark.asyncio
    async def test_status_updated_on_success(self, tmp_path):
        """After a completed run the status is "done" with the run's stop reason."""
        sm = _manager(tmp_path)
        sm.runner.run = AsyncMock(return_value=AgentRunResult(
            final_content="ok", messages=[], stop_reason="completed",
        ))
        tracked = self._fresh_status()
        with patch.object(sm, "_announce_result", new_callable=AsyncMock):
            await self._run(sm, tracked)
        assert tracked.phase == "done"
        assert tracked.stop_reason == "completed"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _announce_result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAnnounceResult:
    """_announce_result(): shape of the inbound message published on the bus."""

    @staticmethod
    def _capture_inbound(sm) -> list:
        """Replace ``sm.bus.publish_inbound`` with a recorder; return the list it fills."""
        published = []
        sm.bus.publish_inbound = AsyncMock(side_effect=lambda msg: published.append(msg))
        return published

    @pytest.mark.asyncio
    async def test_publishes_inbound_message(self, tmp_path):
        """One message is published, as system/subagent with subagent metadata."""
        sm = _manager(tmp_path)
        published = self._capture_inbound(sm)

        await sm._announce_result(
            "t1", "label", "task", "result text",
            {"channel": "cli", "chat_id": "direct"}, "ok",
        )

        assert len(published) == 1
        announced = published[0]
        assert announced.channel == "system"
        assert announced.sender_id == "subagent"
        assert announced.metadata["injected_event"] == "subagent_result"
        assert announced.metadata["subagent_task_id"] == "t1"

    @pytest.mark.asyncio
    async def test_session_key_override(self, tmp_path):
        """A ``session_key`` in the origin is used as the session key override."""
        sm = _manager(tmp_path)
        published = self._capture_inbound(sm)

        await sm._announce_result(
            "t1", "label", "task", "result",
            {"channel": "telegram", "chat_id": "123", "session_key": "s1"}, "ok",
        )

        assert published[0].session_key_override == "s1"

    @pytest.mark.asyncio
    async def test_session_key_override_fallback(self, tmp_path):
        """Without a ``session_key`` the override is ``<channel>:<chat_id>``."""
        sm = _manager(tmp_path)
        published = self._capture_inbound(sm)

        await sm._announce_result(
            "t1", "label", "task", "result",
            {"channel": "telegram", "chat_id": "123"}, "ok",
        )

        assert published[0].session_key_override == "telegram:123"

    @pytest.mark.asyncio
    async def test_ok_status_text(self, tmp_path):
        """Status "ok" yields content containing "completed successfully"."""
        sm = _manager(tmp_path)
        published = self._capture_inbound(sm)

        await sm._announce_result(
            "t1", "label", "task", "result",
            {"channel": "cli", "chat_id": "direct"}, "ok",
        )

        assert "completed successfully" in published[0].content

    @pytest.mark.asyncio
    async def test_error_status_text(self, tmp_path):
        """Status "error" yields content containing "failed"."""
        sm = _manager(tmp_path)
        published = self._capture_inbound(sm)

        await sm._announce_result(
            "t1", "label", "task", "error details",
            {"channel": "cli", "chat_id": "direct"}, "error",
        )

        assert "failed" in published[0].content

    @pytest.mark.asyncio
    async def test_origin_message_id_in_metadata(self, tmp_path):
        """``origin_message_id`` is copied into the message metadata."""
        sm = _manager(tmp_path)
        published = self._capture_inbound(sm)

        await sm._announce_result(
            "t1", "label", "task", "result",
            {"channel": "cli", "chat_id": "direct"}, "ok",
            origin_message_id="msg-123",
        )

        assert published[0].metadata["origin_message_id"] == "msg-123"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _format_partial_progress
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatPartialProgress:
    """_format_partial_progress(): text rendered from tool events and an error."""

    def _make_result(self, tool_events=None, error=None):
        """Return a stand-in result exposing only ``tool_events`` and ``error``."""
        events = tool_events or []
        return MagicMock(tool_events=events, error=error)

    def test_completed_only(self):
        """Successful events are listed under "Completed steps:"."""
        events = [
            {"name": "read_file", "status": "ok", "detail": "file content"},
            {"name": "exec", "status": "ok", "detail": "output"},
        ]
        text = SubagentManager._format_partial_progress(self._make_result(tool_events=events))
        assert "Completed steps:" in text
        assert "read_file" in text
        assert "exec" in text

    def test_failure_only(self):
        """A failed event produces a "Failure:" section with its detail."""
        events = [{"name": "read_file", "status": "error", "detail": "not found"}]
        text = SubagentManager._format_partial_progress(self._make_result(tool_events=events))
        assert "Failure:" in text
        assert "not found" in text

    def test_completed_and_failure(self):
        """Mixed events produce both sections."""
        events = [
            {"name": "read_file", "status": "ok", "detail": "content"},
            {"name": "exec", "status": "error", "detail": "timeout"},
        ]
        text = SubagentManager._format_partial_progress(self._make_result(tool_events=events))
        assert "Completed steps:" in text
        assert "Failure:" in text

    def test_limited_to_last_three(self):
        """Of five successful events only the last three appear."""
        events = [
            {"name": f"tool_{i}", "status": "ok", "detail": f"result_{i}"}
            for i in range(5)
        ]
        text = SubagentManager._format_partial_progress(self._make_result(tool_events=events))
        for shown in ("tool_2", "tool_3", "tool_4"):
            assert shown in text
        for dropped in ("tool_0", "tool_1"):
            assert dropped not in text

    def test_error_without_failure_event(self):
        """The result's error text appears even when no event failed."""
        outcome = self._make_result(
            tool_events=[{"name": "read_file", "status": "ok", "detail": "ok"}],
            error="Something went wrong",
        )
        text = SubagentManager._format_partial_progress(outcome)
        assert "Something went wrong" in text

    def test_empty_events_with_error(self):
        """With no events, the error text alone is rendered."""
        outcome = self._make_result(error="Total failure")
        text = SubagentManager._format_partial_progress(outcome)
        assert "Total failure" in text

    def test_empty_no_error_returns_fallback(self):
        """With neither events nor error, the text still contains "Error"."""
        outcome = self._make_result()
        text = SubagentManager._format_partial_progress(outcome)
        assert "Error" in text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cancel_by_session
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCancelBySession:
    """cancel_by_session(): the number of tasks it reports cancelling."""

    @pytest.mark.asyncio
    async def test_cancels_running_tasks(self, tmp_path):
        """Two blocked tasks in one session are both counted as cancelled."""
        sm = _manager(tmp_path)
        gate = asyncio.Event()

        async def _wait_then_finish(spec):
            await gate.wait()
            return AgentRunResult(final_content="done", messages=[], stop_reason="completed")

        sm.runner.run = _wait_then_finish

        for task_text in ("task1", "task2"):
            await sm.spawn(task_text, session_key="s1")
        assert len(sm._session_tasks.get("s1", set())) == 2

        cancelled = await sm.cancel_by_session("s1")
        assert cancelled == 2

        gate.set()
        await asyncio.sleep(0.1)

    @pytest.mark.asyncio
    async def test_no_tasks_returns_zero(self, tmp_path):
        """An unknown session yields a count of zero."""
        sm = _manager(tmp_path)
        cancelled = await sm.cancel_by_session("nonexistent")
        assert cancelled == 0

    @pytest.mark.asyncio
    async def test_already_done_not_counted(self, tmp_path):
        """A task that already finished is not counted."""
        sm = _manager(tmp_path)
        finished = AgentRunResult(final_content="done", messages=[], stop_reason="completed")
        sm.runner.run = AsyncMock(return_value=finished)
        await sm.spawn("task1", session_key="s1")
        await asyncio.sleep(0.1)  # Wait for completion

        cancelled = await sm.cancel_by_session("s1")
        assert cancelled == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_running_count / get_running_count_by_session
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunningCounts:
    """get_running_count() and get_running_count_by_session()."""

    @pytest.mark.asyncio
    async def test_running_count_zero(self, tmp_path):
        """A fresh manager reports no running tasks."""
        assert _manager(tmp_path).get_running_count() == 0

    @pytest.mark.asyncio
    async def test_running_count_tracks_tasks(self, tmp_path):
        """Both counters report 2 while two tasks are blocked; the total returns to 0."""
        sm = _manager(tmp_path)
        gate = asyncio.Event()

        async def _wait_then_finish(spec):
            await gate.wait()
            return AgentRunResult(final_content="done", messages=[], stop_reason="completed")

        sm.runner.run = _wait_then_finish

        for task_text in ("t1", "t2"):
            await sm.spawn(task_text, session_key="s1")
        assert sm.get_running_count() == 2
        assert sm.get_running_count_by_session("s1") == 2

        gate.set()
        await asyncio.sleep(0.1)
        assert sm.get_running_count() == 0

    @pytest.mark.asyncio
    async def test_running_count_by_session_nonexistent(self, tmp_path):
        """An unknown session reports zero running tasks."""
        assert _manager(tmp_path).get_running_count_by_session("nonexistent") == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _SubagentHook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubagentHook:
    """_SubagentHook callbacks: before_execute_tools and after_iteration."""

    @staticmethod
    def _new_status() -> SubagentStatus:
        """Return a fresh status for task ``t1``."""
        return SubagentStatus(
            task_id="t1",
            label="test",
            task_description="do",
            started_at=time.monotonic(),
        )

    @pytest.mark.asyncio
    async def test_before_execute_tools_logs(self, tmp_path):
        """before_execute_tools accepts a context with one tool call."""
        pending_call = MagicMock()
        pending_call.name = "read_file"
        pending_call.arguments = {"path": "/tmp/test"}
        ctx = _make_hook_context(tool_calls=[pending_call])
        hook = _SubagentHook("t1")
        # Should not raise
        await hook.before_execute_tools(ctx)

    @pytest.mark.asyncio
    async def test_after_iteration_updates_status(self):
        """Iteration, tool events and usage are copied from the context to the status."""
        tracked = self._new_status()
        ctx = _make_hook_context(
            iteration=3,
            tool_events=[{"name": "read_file", "status": "ok", "detail": ""}],
            usage={"prompt_tokens": 100},
        )
        await _SubagentHook("t1", tracked).after_iteration(ctx)
        assert tracked.iteration == 3
        assert len(tracked.tool_events) == 1
        assert tracked.usage == {"prompt_tokens": 100}

    @pytest.mark.asyncio
    async def test_after_iteration_no_status_noop(self):
        """With ``status=None`` the callback completes without raising."""
        hook = _SubagentHook("t1", status=None)
        ctx = _make_hook_context(iteration=5)
        # Should not raise
        await hook.after_iteration(ctx)

    @pytest.mark.asyncio
    async def test_after_iteration_sets_error(self):
        """An error on the context is copied onto the status."""
        tracked = self._new_status()
        ctx = _make_hook_context(error="something broke")
        await _SubagentHook("t1", tracked).after_iteration(ctx)
        assert tracked.error == "something broke"
|
||||
Loading…
x
Reference in New Issue
Block a user