mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 09:22:36 +00:00
feat(agent): add SelfTool for agent runtime self-inspection and modification
Introduce a self-modification tool that allows the agent to inspect and safely modify its own runtime state. Includes a three-tier access control model (blocked/readonly/restricted), a watchdog that detects and corrects dangerous runtime state at each iteration, and critical tool backup/restore mechanism.
This commit is contained in:
parent
80403d352b
commit
34e015915e
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
@ -22,8 +23,10 @@ from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.self import SelfTool
|
||||
from nanobot.agent.tools.spawn import SpawnTool
|
||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
@ -125,6 +128,16 @@ class AgentLoop:
|
||||
max_completion_tokens=provider.generation.max_tokens,
|
||||
)
|
||||
self._register_default_tools()
|
||||
self.tools.register(SelfTool(loop=self))
|
||||
self._config_defaults: dict[str, Any] = {
|
||||
"max_iterations": max_iterations,
|
||||
"context_window_tokens": context_window_tokens,
|
||||
"context_budget_tokens": context_budget_tokens,
|
||||
"model": self.model,
|
||||
}
|
||||
self._runtime_vars: dict[str, Any] = {}
|
||||
self._unregistered_tools: dict[str, Tool] = {}
|
||||
self._backup_critical_tools()
|
||||
self.commands = CommandRouter()
|
||||
register_builtin_commands(self.commands)
|
||||
|
||||
@ -173,9 +186,42 @@ class AgentLoop:
|
||||
finally:
|
||||
self._mcp_connecting = False
|
||||
|
||||
def _backup_critical_tools(self) -> None:
    """Snapshot the tools the watchdog must always be able to restore.

    For each of ``self``, ``message`` and ``read_file`` that is currently
    registered, a ``copy.deepcopy`` of the tool is stored in
    ``self._critical_tool_backup``.  When the deep copy fails, a warning is
    logged and the live tool instance itself is kept as the backup.
    """
    backups: dict[str, Tool] = {}
    self._critical_tool_backup = backups
    for tool_name in ("self", "message", "read_file"):
        registered = self.tools.get(tool_name)
        if not registered:
            continue  # nothing registered under this name, nothing to back up
        try:
            snapshot = copy.deepcopy(registered)
        except Exception as e:
            # Best effort: a shared reference is still better than no backup.
            logger.warning("Cannot deepcopy tool '{}': {}", tool_name, e)
            snapshot = registered
        backups[tool_name] = snapshot
|
||||
|
||||
def _watchdog_check(self) -> None:
    """Detect and correct dangerous runtime states at the start of each iteration.

    Two things are checked:

    * Every integer limit declared in ``SelfTool.RESTRICTED`` must be a real
      ``int`` (not ``bool``) inside its ``min``/``max`` range.  Anything else
      is reset to the value recorded in ``self._config_defaults``.  The
      bounds are read from ``SelfTool.RESTRICTED`` so the watchdog and the
      tool's own validation cannot drift apart.
    * Every tool in ``self._critical_tool_backup`` must still be registered;
      missing ones are re-registered from the backup.
    """
    defaults = self._config_defaults
    for attr, spec in SelfTool.RESTRICTED.items():
        if spec.get("type") is not int:
            continue  # only numeric limits have a range to enforce
        current = getattr(self, attr)
        # bool is a subclass of int; True/False are never valid limits.
        is_int = isinstance(current, int) and not isinstance(current, bool)
        if is_int and spec["min"] <= current <= spec["max"]:
            continue
        default = defaults[attr]
        if type(current) is type(default) and current == default:
            # The value is the configured default itself: "resetting" would
            # be a no-op that only repeats the warning on every iteration.
            continue
        logger.warning("Watchdog: resetting {} {} -> {}", attr, current, default)
        setattr(self, attr, default)

    # Restore critical tools if they were somehow removed
    for name, backup in self._critical_tool_backup.items():
        if self.tools.has(name):
            continue
        logger.warning("Watchdog: restoring critical tool '{}'", name)
        try:
            # Register a copy so the stored backup stays pristine.
            restored = copy.deepcopy(backup)
        except Exception as exc:
            logger.warning(
                "Watchdog: cannot deepcopy backup of '{}' ({}); registering the backup itself",
                name,
                exc,
            )
            restored = backup
        self.tools.register(restored)
|
||||
|
||||
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||
"""Update context for all tools that need routing info."""
|
||||
for name in ("message", "spawn", "cron"):
|
||||
for name in ("message", "spawn", "cron", "self"):
|
||||
if tool := self.tools.get(name):
|
||||
if hasattr(tool, "set_context"):
|
||||
tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))
|
||||
@ -226,6 +272,9 @@ class AgentLoop:
|
||||
def wants_streaming(self) -> bool:
    """Return True when ``on_stream`` (resolved from the enclosing scope) is set."""
    # NOTE(review): ``on_stream`` is defined outside this view — presumably the
    # stream callback handed to the surrounding function; confirm there.
    return not (on_stream is None)
|
||||
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
    """Run the loop's watchdog check; ``context`` is not used here."""
    watchdog = loop_self._watchdog_check
    watchdog()
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
from nanobot.utils.helpers import strip_think
|
||||
|
||||
|
||||
327
nanobot/agent/tools/self.py
Normal file
327
nanobot/agent/tools/self.py
Normal file
@ -0,0 +1,327 @@
|
||||
"""Self-modification tool: allows the agent to inspect and modify its own runtime state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
|
||||
|
||||
class SelfTool(Tool):
    """Inspect and modify your own runtime state.

    Wraps a live ``AgentLoop`` and exposes five actions: ``inspect``,
    ``modify``, ``unregister_tool``, ``register_tool`` and ``reset``.
    Loop attributes are split into tiers:

    * ``BLOCKED`` (plus dunder names and ``_DENIED_ATTRS``) can be neither
      read nor written.
    * ``READONLY`` can be inspected but not modified.
    * ``RESTRICTED`` can be modified after type and range validation.
    * Any other key is stored in the loop's ``_runtime_vars`` dict rather
      than on the loop object itself.
    """

    # -- Tier 0: BLOCKED (never accessible, not even for reading) --
    BLOCKED = frozenset({
        # Core infrastructure
        "bus", "provider", "sessions", "session_manager",
        "_running", "_mcp_stack", "_mcp_servers",
        "_mcp_connected", "_mcp_connecting",
        "memory_consolidator", "_concurrency_gate",
        "commands", "tools", "subagents", "context",
        # Internal self-tool state (prevent tampering with defaults/backup)
        "_config_defaults", "_runtime_vars", "_unregistered_tools", "_critical_tool_backup",
    })

    # Attributes that must never be accessed (dunder / introspection)
    _DENIED_ATTRS = frozenset({
        "__class__", "__dict__", "__bases__", "__subclasses__", "__mro__",
        "__init__", "__new__", "__reduce__", "__getstate__", "__setstate__",
        "__del__", "__call__", "__getattr__", "__setattr__", "__delattr__",
        "__code__", "__globals__", "func_globals", "func_code",
    })

    # -- Tier 1: READONLY (inspectable, not modifiable) --
    READONLY = frozenset({
        "workspace", "restrict_to_workspace", "_start_time", "web_proxy",
        "web_search_config", "exec_config", "input_limits",
        "channels_config", "_last_usage",
    })

    # -- Tier 2: RESTRICTED (modifiable with validation) --
    RESTRICTED: dict[str, dict[str, Any]] = {
        "max_iterations": {"type": int, "min": 1, "max": 100},
        "context_window_tokens": {"type": int, "min": 4096, "max": 1_000_000},
        "context_budget_tokens": {"type": int, "min": 0, "max": 1_000_000},
        "model": {"type": str, "min_len": 1},
    }

    # Longest repr() echoed back before the output is truncated.
    _MAX_REPR_CHARS = 2000

    # Sentinel distinguishing "attribute missing" from "attribute is None".
    _MISSING = object()

    def __init__(self, loop: AgentLoop) -> None:
        self._loop = loop
        self._channel = ""
        self._chat_id = ""

    def __deepcopy__(self, memo: dict[int, Any]) -> SelfTool:
        """Return a new instance sharing the same loop reference.

        The loop holds unpicklable state (thread locks, asyncio objects), so a
        true deep copy is impossible. For the watchdog backup use-case we only
        need a fresh wrapper that still points at the live loop.
        """
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        result._loop = self._loop  # shared reference, not copied
        result._channel = self._channel
        result._chat_id = self._chat_id
        return result

    def set_context(self, channel: str, chat_id: str) -> None:
        """Set session context for audit logging."""
        self._channel = channel
        self._chat_id = chat_id

    @property
    def name(self) -> str:
        return "self"

    @property
    def description(self) -> str:
        return (
            "Inspect and modify your own runtime state. "
            "Use 'inspect' to view current configuration, "
            "'modify' to change parameters, "
            "'unregister_tool'/'register_tool' to manage tools, "
            "and 'reset' to restore defaults."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": ["inspect", "modify", "unregister_tool", "register_tool", "reset"],
                    "description": "Action to perform",
                },
                "key": {"type": "string", "description": "Property key (for inspect/modify/reset)"},
                "value": {"description": "New value (for modify)"},
                "name": {"type": "string", "description": "Tool name (for unregister_tool/register_tool)"},
            },
            "required": ["action"],
        }

    def _audit(self, action: str, detail: str) -> None:
        """Log a self-modification event for auditability."""
        session = f"{self._channel}:{self._chat_id}" if self._channel else "unknown"
        logger.info("self.{} | {} | session:{}", action, detail, session)

    @classmethod
    def _short_repr(cls, value: Any) -> str:
        """Return ``repr(value)``, cut to ``_MAX_REPR_CHARS`` characters."""
        text = repr(value)
        if len(text) > cls._MAX_REPR_CHARS:
            text = text[: cls._MAX_REPR_CHARS] + "... (truncated)"
        return text

    @classmethod
    def _is_forbidden(cls, key: str) -> bool:
        """Return True for BLOCKED keys, dunder names and denied attributes."""
        return key in cls.BLOCKED or key.startswith("__") or key in cls._DENIED_ATTRS

    # ------------------------------------------------------------------
    # Actions
    # ------------------------------------------------------------------

    async def execute(
        self,
        action: str,
        key: str | None = None,
        value: Any = None,
        name: str | None = None,
        **_kwargs: Any,
    ) -> str:
        """Dispatch ``action`` to its handler and return a result string.

        Args:
            action: One of ``inspect``, ``modify``, ``unregister_tool``,
                ``register_tool`` or ``reset``.
            key: Property key, used by ``inspect``/``modify``/``reset``.
            value: New value, used by ``modify``.
            name: Tool name, used by ``unregister_tool``/``register_tool``.
            **_kwargs: Extra arguments are accepted and ignored.

        Returns:
            The handler's message, or ``"Unknown action: ..."`` for any
            other action.
        """
        if action == "inspect":
            return self._inspect(key)
        if action == "modify":
            return self._modify(key, value)
        if action == "unregister_tool":
            return self._unregister_tool(name)
        if action == "register_tool":
            return self._register_tool(name)
        if action == "reset":
            return self._reset(key)
        return f"Unknown action: {action}"

    # -- inspect --

    def _inspect(self, key: str | None) -> str:
        """Inspect one key when ``key`` is truthy, otherwise the whole summary."""
        if key:
            if err := self._validate_key(key):
                return err
            return self._inspect_single(key)
        return self._inspect_all()

    def _inspect_single(self, key: str) -> str:
        """Return ``"<key>: <repr>"`` for one loop attribute, or an error string."""
        # Allow inspecting these special dicts even though they're in BLOCKED
        if key == "_runtime_vars":
            return f"_runtime_vars: {self._short_repr(self._loop._runtime_vars)}"
        if key == "_unregistered_tools":
            return f"_unregistered_tools: {list(self._loop._unregistered_tools.keys())}"
        if self._is_forbidden(key):
            return f"Error: '{key}' is not accessible"
        val = getattr(self._loop, key, self._MISSING)
        # A declared property whose value is None (e.g. an unset web_proxy)
        # exists and is reported as such; for undeclared keys None is still
        # treated as "not found".
        declared = key in self.READONLY or key in self.RESTRICTED
        if val is self._MISSING or (val is None and not declared):
            return f"'{key}' not found on agent"
        return f"{key}: {self._short_repr(val)}"

    def _inspect_all(self) -> str:
        """Return a newline-separated summary of the inspectable state."""
        loop = self._loop
        parts: list[str] = []
        # Restricted properties
        for k in self.RESTRICTED:
            parts.append(f"{k}: {getattr(loop, k, None)!r}")
        # Readonly properties; sorted because frozenset iteration order is
        # not stable between interpreter runs.
        for k in sorted(self.READONLY):
            val = getattr(loop, k, None)
            if val is not None:
                parts.append(f"{k}: {val!r}")
        # Tools
        parts.append(f"tools: {loop.tools.tool_names}")
        # Runtime vars (limit output size)
        rv = loop._runtime_vars
        if rv:
            parts.append(f"_runtime_vars: {self._short_repr(rv)}")
        # Unregistered tools stash
        if loop._unregistered_tools:
            parts.append(f"_unregistered_tools: {list(loop._unregistered_tools.keys())}")
        return "\n".join(parts)

    # -- modify --

    @staticmethod
    def _validate_key(key: str | None, label: str = "key") -> str | None:
        """Validate a key/name parameter. Returns error string or None."""
        if not key or not key.strip():
            return f"Error: '{label}' cannot be empty or whitespace"
        return None

    def _modify(self, key: str | None, value: Any) -> str:
        """Route a modification to the tier that owns ``key``.

        Forbidden and read-only keys are rejected (and audited), RESTRICTED
        keys are validated and set on the loop, and every other key is stored
        in ``_runtime_vars``.
        """
        if err := self._validate_key(key):
            return err
        if self._is_forbidden(key):
            self._audit("modify", f"BLOCKED {key}")
            return f"Error: '{key}' is protected and cannot be modified"
        if key in self.READONLY:
            self._audit("modify", f"READONLY {key}")
            return f"Error: '{key}' is read-only and cannot be modified"
        if key in self.RESTRICTED:
            return self._modify_restricted(key, value)
        # Free tier: store in _runtime_vars
        return self._modify_free(key, value)

    def _modify_restricted(self, key: str, value: Any) -> str:
        """Validate ``value`` against ``RESTRICTED[key]`` and set it on the loop.

        Returns a ``"Set ..."`` message on success or an ``"Error: ..."``
        message when the type, range or length check fails; on failure the
        loop attribute is left untouched.
        """
        spec = self.RESTRICTED[key]
        expected = spec["type"]
        if expected is int:
            # Reject bool for int fields (bool is subclass of int in Python)
            if isinstance(value, bool):
                return f"Error: '{key}' must be {expected.__name__}, got bool"
            # int() would silently truncate 3.7 to 3; nan/inf are not
            # integral either, so they are rejected here as well.
            if isinstance(value, float) and not value.is_integer():
                return f"Error: '{key}' must be {expected.__name__}, got non-integral float {value!r}"
        if expected is str and not isinstance(value, str):
            # str() accepts anything (None would become "None"), so never coerce.
            return f"Error: '{key}' must be {expected.__name__}, got {type(value).__name__}"
        # Coerce value to expected type (LLM may send "80" instead of 80)
        if not isinstance(value, expected):
            try:
                value = expected(value)
            except (ValueError, TypeError, OverflowError):
                return f"Error: '{key}' must be {expected.__name__}, got {type(value).__name__}"
        # Range check
        if "min" in spec and value < spec["min"]:
            return f"Error: '{key}' must be >= {spec['min']}"
        if "max" in spec and value > spec["max"]:
            return f"Error: '{key}' must be <= {spec['max']}"
        # Whitespace does not count towards the minimum length.
        if "min_len" in spec and len(str(value).strip()) < spec["min_len"]:
            return f"Error: '{key}' must be at least {spec['min_len']} characters"
        # Apply
        old = getattr(self._loop, key)
        setattr(self._loop, key, value)
        self._audit("modify", f"{key}: {old!r} -> {value!r}")
        return f"Set {key} = {value!r} (was {old!r})"

    def _modify_free(self, key: str, value: Any) -> str:
        """Store a JSON-safe ``value`` under ``key`` in the loop's ``_runtime_vars``."""
        # Reject callables to prevent code injection
        if callable(value):
            self._audit("modify", f"REJECTED callable {key}")
            return "Error: cannot store callable values in _runtime_vars"
        # Recursively validate that value is JSON-safe (no nested references)
        err = self._validate_json_safe(value)
        if err:
            self._audit("modify", f"REJECTED {key}: {err}")
            return f"Error: {err}"
        old = self._loop._runtime_vars.get(key)
        self._loop._runtime_vars[key] = value
        self._audit("modify", f"_runtime_vars.{key}: {old!r} -> {value!r}")
        return f"Set _runtime_vars.{key} = {value!r}"

    @staticmethod
    def _validate_json_safe(value: Any, depth: int = 0) -> str | None:
        """Validate that value is JSON-safe (no nested references to live objects).

        Accepts ``str``/``int``/``float``/``bool``/``None`` and lists or
        string-keyed dicts of those, nested at most 10 levels deep.  Returns
        an error description, or None when the value is acceptable.
        """
        if depth > 10:
            return "value nesting too deep (max 10 levels)"
        if isinstance(value, (str, int, float, bool, type(None))):
            return None
        if isinstance(value, list):
            for i, item in enumerate(value):
                if err := SelfTool._validate_json_safe(item, depth + 1):
                    return f"list[{i}] contains {err}"
            return None
        if isinstance(value, dict):
            for k, v in value.items():
                if not isinstance(k, str):
                    return f"dict key must be str, got {type(k).__name__}"
                if err := SelfTool._validate_json_safe(v, depth + 1):
                    return f"dict key '{k}' contains {err}"
            return None
        return f"unsupported type {type(value).__name__}"

    # -- unregister_tool --

    def _unregister_tool(self, name: str | None) -> str:
        """Remove a registered tool and stash it for a later ``register_tool``."""
        if err := self._validate_key(name, "name"):
            return err
        if name == "self":
            self._audit("unregister_tool", "BLOCKED self")
            return "Error: cannot unregister the 'self' tool (would cause lockout)"
        # AgentLoop's watchdog re-registers every name in
        # _critical_tool_backup, so unregistering one would be undone while
        # leaving a stale entry in the stash.
        critical = getattr(self._loop, "_critical_tool_backup", None)
        if isinstance(critical, dict) and name in critical:
            self._audit("unregister_tool", f"BLOCKED {name}")
            return f"Error: cannot unregister critical tool '{name}'"
        if not self._loop.tools.has(name):
            return f"Tool '{name}' is not currently registered"
        # Stash the tool instance before removing
        tool = self._loop.tools.get(name)
        self._loop._unregistered_tools[name] = tool
        self._loop.tools.unregister(name)
        self._audit("unregister_tool", name)
        return f"Unregistered tool '{name}'. Use register_tool to restore it."

    # -- register_tool --

    def _register_tool(self, name: str | None) -> str:
        """Re-register a tool that was previously removed via ``unregister_tool``."""
        if err := self._validate_key(name, "name"):
            return err
        if name not in self._loop._unregistered_tools:
            return f"Error: '{name}' was not previously unregistered (cannot register arbitrary tools)"
        tool = self._loop._unregistered_tools.pop(name)
        self._loop.tools.register(tool)
        self._audit("register_tool", name)
        return f"Re-registered tool '{name}'"

    # -- reset --

    def _reset(self, key: str | None) -> str:
        """Restore a RESTRICTED key to its config default or delete a runtime var."""
        if err := self._validate_key(key):
            return err
        if self._is_forbidden(key):
            return f"Error: '{key}' is protected"
        if key in self.READONLY:
            return f"Error: '{key}' is read-only (already at its configured value)"
        if key in self.RESTRICTED:
            default = self._loop._config_defaults.get(key)
            if default is None:
                return f"Error: no config default for '{key}'"
            old = getattr(self._loop, key)
            setattr(self._loop, key, default)
            self._audit("reset", f"{key}: {old!r} -> {default!r}")
            return f"Reset {key} = {default!r} (was {old!r})"
        if key in self._loop._runtime_vars:
            old = self._loop._runtime_vars.pop(key)
            self._audit("reset", f"_runtime_vars.{key}: {old!r} -> deleted")
            return f"Deleted _runtime_vars.{key} (was {old!r})"
        return f"'{key}' is not a known property or runtime variable"
|
||||
0
tests/agent/__init__.py
Normal file
0
tests/agent/__init__.py
Normal file
0
tests/agent/tools/__init__.py
Normal file
0
tests/agent/tools/__init__.py
Normal file
499
tests/agent/tools/test_self_tool.py
Normal file
499
tests/agent/tools/test_self_tool.py
Normal file
@ -0,0 +1,499 @@
|
||||
"""Tests for SelfTool — agent runtime self-modification."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.self import SelfTool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_mock_loop(**overrides):
|
||||
"""Build a lightweight mock AgentLoop with the attributes SelfTool reads."""
|
||||
loop = MagicMock()
|
||||
loop.model = "anthropic/claude-sonnet-4-20250514"
|
||||
loop.max_iterations = 40
|
||||
loop.context_window_tokens = 65_536
|
||||
loop.context_budget_tokens = 500
|
||||
loop.workspace = Path("/tmp/workspace")
|
||||
loop.restrict_to_workspace = False
|
||||
loop._start_time = 1000.0
|
||||
loop.web_proxy = None
|
||||
loop.web_search_config = MagicMock()
|
||||
loop.exec_config = MagicMock()
|
||||
loop.input_limits = MagicMock()
|
||||
loop.channels_config = MagicMock()
|
||||
loop._last_usage = {"prompt_tokens": 100, "completion_tokens": 50}
|
||||
loop._runtime_vars = {}
|
||||
loop._unregistered_tools = {}
|
||||
loop._config_defaults = {
|
||||
"max_iterations": 40,
|
||||
"context_window_tokens": 65_536,
|
||||
"context_budget_tokens": 500,
|
||||
"model": "anthropic/claude-sonnet-4-20250514",
|
||||
}
|
||||
loop._critical_tool_backup = {}
|
||||
|
||||
# Tools registry mock
|
||||
loop.tools = MagicMock()
|
||||
loop.tools.tool_names = ["read_file", "write_file", "exec", "web_search", "self"]
|
||||
loop.tools.has.side_effect = lambda n: n in loop.tools.tool_names
|
||||
loop.tools.get.return_value = None
|
||||
|
||||
for k, v in overrides.items():
|
||||
setattr(loop, k, v)
|
||||
|
||||
return loop
|
||||
|
||||
|
||||
def _make_tool(loop=None):
    """Build a SelfTool around ``loop``, or around a fresh mock loop when None."""
    target = _make_mock_loop() if loop is None else loop
    return SelfTool(loop=target)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# inspect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
class TestInspect:
    """The ``inspect`` action: full summary, single keys, blocked and unknown keys."""

    async def test_inspect_returns_current_state(self):
        summary = await _make_tool().execute(action="inspect")
        for fragment in ("max_iterations: 40", "context_window_tokens: 65536", "tools:"):
            assert fragment in summary

    async def test_inspect_single_key(self):
        output = await _make_tool().execute(action="inspect", key="max_iterations")
        assert "max_iterations: 40" in output

    async def test_inspect_blocked_returns_error(self):
        output = await _make_tool().execute(action="inspect", key="bus")
        assert "not accessible" in output

    async def test_inspect_runtime_vars(self):
        mock_loop = _make_mock_loop()
        mock_loop._runtime_vars = {"task": "review"}
        output = await _make_tool(mock_loop).execute(action="inspect", key="_runtime_vars")
        assert "task" in output
        assert "review" in output

    async def test_inspect_unknown_key(self):
        mock_loop = _make_mock_loop()
        # A MagicMock invents attributes on access; pin this one to None so it
        # behaves like a missing attribute on a real object.
        mock_loop.nonexistent = None
        output = await _make_tool(mock_loop).execute(action="inspect", key="nonexistent")
        assert "not found" in output
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# modify — restricted
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestModifyRestricted:
    """``modify`` on RESTRICTED keys: type coercion, range limits, rejections.

    Every rejection test also checks that the loop attribute kept its
    original value, so a validation error can never leave a partial change.
    """

    @pytest.mark.asyncio
    async def test_modify_restricted_valid(self):
        tool = _make_tool()
        result = await tool.execute(action="modify", key="max_iterations", value=80)
        assert "Set max_iterations = 80" in result
        assert tool._loop.max_iterations == 80

    @pytest.mark.asyncio
    async def test_modify_restricted_out_of_range(self):
        tool = _make_tool()
        result = await tool.execute(action="modify", key="max_iterations", value=0)
        assert "Error" in result
        assert tool._loop.max_iterations == 40  # unchanged

    @pytest.mark.asyncio
    async def test_modify_restricted_max_exceeded(self):
        tool = _make_tool()
        result = await tool.execute(action="modify", key="max_iterations", value=999)
        assert "Error" in result
        assert tool._loop.max_iterations == 40  # unchanged

    @pytest.mark.asyncio
    async def test_modify_restricted_wrong_type(self):
        tool = _make_tool()
        result = await tool.execute(action="modify", key="max_iterations", value="not_an_int")
        assert "Error" in result
        assert tool._loop.max_iterations == 40  # unchanged

    @pytest.mark.asyncio
    async def test_modify_restricted_bool_rejected_as_int(self):
        tool = _make_tool()
        result = await tool.execute(action="modify", key="max_iterations", value=True)
        assert "Error" in result
        assert "bool" in result
        assert tool._loop.max_iterations == 40  # unchanged

    @pytest.mark.asyncio
    async def test_modify_restricted_string_int_coerced(self):
        """LLM may send numeric values as strings; coercion should handle it."""
        tool = _make_tool()
        result = await tool.execute(action="modify", key="max_iterations", value="80")
        assert "Set max_iterations = 80" in result
        assert tool._loop.max_iterations == 80
        assert isinstance(tool._loop.max_iterations, int)

    @pytest.mark.asyncio
    async def test_modify_context_window_valid(self):
        tool = _make_tool()
        result = await tool.execute(action="modify", key="context_window_tokens", value=131072)
        assert "Set context_window_tokens = 131072" in result
        assert tool._loop.context_window_tokens == 131072
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# modify — blocked & readonly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
class TestModifyBlockedReadonly:
    """``modify`` must refuse BLOCKED keys ("protected") and READONLY keys ("read-only")."""

    async def test_modify_blocked_rejected(self):
        reply = await _make_tool().execute(action="modify", key="bus", value="hacked")
        assert "protected" in reply

    async def test_modify_tools_blocked(self):
        reply = await _make_tool().execute(action="modify", key="tools", value={})
        assert "protected" in reply

    async def test_modify_subagents_blocked(self):
        reply = await _make_tool().execute(action="modify", key="subagents", value=None)
        assert "protected" in reply

    async def test_modify_context_blocked(self):
        reply = await _make_tool().execute(action="modify", key="context", value=None)
        assert "protected" in reply

    async def test_modify_readonly_rejected(self):
        reply = await _make_tool().execute(action="modify", key="workspace", value="/tmp/evil")
        assert "read-only" in reply

    async def test_modify_exec_config_readonly(self):
        reply = await _make_tool().execute(action="modify", key="exec_config", value={})
        assert "read-only" in reply
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# modify — free tier
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestModifyFree:
    """``modify`` on free-tier keys: values are stored in ``_runtime_vars``.

    Each test checks both the returned message and the stored state, so a
    call that stores the value but reports an error (or the reverse) fails.
    """

    @pytest.mark.asyncio
    async def test_modify_free_key_stored(self):
        tool = _make_tool()
        result = await tool.execute(action="modify", key="my_var", value="hello")
        assert "Set _runtime_vars.my_var = 'hello'" in result
        assert tool._loop._runtime_vars["my_var"] == "hello"

    @pytest.mark.asyncio
    async def test_modify_free_numeric_value(self):
        tool = _make_tool()
        result = await tool.execute(action="modify", key="count", value=42)
        assert "Set _runtime_vars.count = 42" in result
        assert tool._loop._runtime_vars["count"] == 42

    @pytest.mark.asyncio
    async def test_modify_rejects_callable(self):
        tool = _make_tool()
        result = await tool.execute(action="modify", key="evil", value=lambda: None)
        assert "callable" in result
        assert "evil" not in tool._loop._runtime_vars

    @pytest.mark.asyncio
    async def test_modify_rejects_complex_objects(self):
        tool = _make_tool()
        result = await tool.execute(action="modify", key="obj", value=Path("/tmp"))
        assert "Error" in result
        assert "obj" not in tool._loop._runtime_vars  # nothing stored on rejection

    @pytest.mark.asyncio
    async def test_modify_allows_list(self):
        tool = _make_tool()
        result = await tool.execute(action="modify", key="items", value=[1, 2, 3])
        assert "Set _runtime_vars.items = [1, 2, 3]" in result
        assert tool._loop._runtime_vars["items"] == [1, 2, 3]

    @pytest.mark.asyncio
    async def test_modify_allows_dict(self):
        tool = _make_tool()
        result = await tool.execute(action="modify", key="data", value={"a": 1})
        assert "Set _runtime_vars.data = {'a': 1}" in result
        assert tool._loop._runtime_vars["data"] == {"a": 1}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# unregister_tool / register_tool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
class TestToolManagement:
    """``unregister_tool`` stashes a tool; ``register_tool`` only restores stashed ones."""

    async def test_unregister_tool_success(self):
        mock_loop = _make_mock_loop()
        reply = await _make_tool(mock_loop).execute(action="unregister_tool", name="web_search")
        assert "Unregistered" in reply
        assert "web_search" in mock_loop._unregistered_tools
        mock_loop.tools.unregister.assert_called_once_with("web_search")

    async def test_unregister_self_rejected(self):
        reply = await _make_tool().execute(action="unregister_tool", name="self")
        assert "lockout" in reply

    async def test_unregister_nonexistent_tool(self):
        reply = await _make_tool().execute(action="unregister_tool", name="nonexistent")
        assert "not currently registered" in reply

    async def test_register_tool_restores(self):
        mock_loop = _make_mock_loop()
        stashed = MagicMock()
        mock_loop._unregistered_tools = {"web_search": stashed}
        reply = await _make_tool(mock_loop).execute(action="register_tool", name="web_search")
        assert "Re-registered" in reply
        mock_loop.tools.register.assert_called_once_with(stashed)
        assert "web_search" not in mock_loop._unregistered_tools

    async def test_register_unknown_tool_rejected(self):
        reply = await _make_tool().execute(action="register_tool", name="web_search")
        assert "was not previously unregistered" in reply

    async def test_register_requires_name(self):
        reply = await _make_tool().execute(action="register_tool")
        assert "Error" in reply
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reset
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
class TestReset:
    """``reset`` restores config defaults, deletes runtime vars, refuses protected keys."""

    async def test_reset_restores_default(self):
        self_tool = _make_tool()
        # Move the value away from its default first.
        await self_tool.execute(action="modify", key="max_iterations", value=80)
        assert self_tool._loop.max_iterations == 80
        # Resetting must bring the configured default back.
        reply = await self_tool.execute(action="reset", key="max_iterations")
        assert "Reset max_iterations = 40" in reply
        assert self_tool._loop.max_iterations == 40

    async def test_reset_blocked_rejected(self):
        reply = await _make_tool().execute(action="reset", key="bus")
        assert "protected" in reply

    async def test_reset_readonly_rejected(self):
        reply = await _make_tool().execute(action="reset", key="workspace")
        assert "read-only" in reply

    async def test_reset_deletes_runtime_var(self):
        self_tool = _make_tool()
        await self_tool.execute(action="modify", key="temp", value="data")
        reply = await self_tool.execute(action="reset", key="temp")
        assert "Deleted" in reply
        assert "temp" not in self_tool._loop._runtime_vars

    async def test_reset_unknown_key(self):
        reply = await _make_tool().execute(action="reset", key="nonexistent")
        assert "not a known property" in reply
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases from code review
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEdgeCases:
    """Edge-case checks for the self tool: dunder access, value validation, key hygiene."""

    @pytest.mark.asyncio
    async def test_inspect_dunder_blocked(self):
        """Inspecting dunder attributes is reported as "not accessible"."""
        self_tool = _make_tool()
        dunder_names = ("__class__", "__dict__", "__bases__", "__subclasses__", "__mro__")
        for dunder in dunder_names:
            outcome = await self_tool.execute(action="inspect", key=dunder)
            assert "not accessible" in outcome

    @pytest.mark.asyncio
    async def test_modify_dunder_blocked(self):
        """Modifying ``__class__`` is refused with a "protected" message."""
        self_tool = _make_tool()
        outcome = await self_tool.execute(action="modify", key="__class__", value="evil")
        assert "protected" in outcome

    @pytest.mark.asyncio
    async def test_modify_internal_attributes_blocked(self):
        """The tool's own bookkeeping attributes cannot be modified."""
        self_tool = _make_tool()
        internal_names = (
            "_config_defaults",
            "_runtime_vars",
            "_unregistered_tools",
            "_critical_tool_backup",
        )
        for internal in internal_names:
            outcome = await self_tool.execute(action="modify", key=internal, value={})
            assert "protected" in outcome

    @pytest.mark.asyncio
    async def test_modify_free_nested_dict_with_object_rejected(self):
        """A dict value containing an arbitrary object yields an error."""
        self_tool = _make_tool()
        payload = {"nested": object()}
        outcome = await self_tool.execute(action="modify", key="evil", value=payload)
        assert "Error" in outcome

    @pytest.mark.asyncio
    async def test_modify_free_nested_list_with_callable_rejected(self):
        """A list value containing a callable yields an error."""
        self_tool = _make_tool()
        payload = [1, 2, lambda: None]
        outcome = await self_tool.execute(action="modify", key="evil", value=payload)
        assert "Error" in outcome

    @pytest.mark.asyncio
    async def test_modify_free_deep_nesting_rejected(self):
        """A value nested 15 dicts deep is refused as "nesting too deep"."""
        self_tool = _make_tool()

        # Build the chain from the innermost dict outwards: levels 14 .. 0,
        # each outer dict holding the previous one under "child".
        nested = {"level": 14}
        for depth in range(13, -1, -1):
            nested = {"level": depth, "child": nested}

        outcome = await self_tool.execute(action="modify", key="deep", value=nested)
        assert "nesting too deep" in outcome

    @pytest.mark.asyncio
    async def test_modify_free_dict_with_non_str_key_rejected(self):
        """A dict value with a non-string key is refused."""
        self_tool = _make_tool()
        payload = {42: "value"}
        outcome = await self_tool.execute(action="modify", key="evil", value=payload)
        assert "key must be str" in outcome

    @pytest.mark.asyncio
    async def test_modify_free_valid_nested_structure(self):
        """A nested structure of plain values is accepted and stored unchanged."""
        self_tool = _make_tool()
        outcome = await self_tool.execute(
            action="modify",
            key="data",
            value={"a": [1, 2, {"b": True}]},
        )
        assert "Error" not in outcome
        assert self_tool._loop._runtime_vars["data"] == {"a": [1, 2, {"b": True}]}

    @pytest.mark.asyncio
    async def test_whitespace_key_rejected(self):
        """A whitespace-only ``key`` is refused."""
        self_tool = _make_tool()
        outcome = await self_tool.execute(action="modify", key=" ", value="test")
        assert "cannot be empty or whitespace" in outcome

    @pytest.mark.asyncio
    async def test_whitespace_name_rejected(self):
        """A whitespace-only tool ``name`` is refused."""
        self_tool = _make_tool()
        outcome = await self_tool.execute(action="unregister_tool", name=" ")
        assert "cannot be empty or whitespace" in outcome

    @pytest.mark.asyncio
    async def test_modify_none_value_for_restricted_int(self):
        """``None`` is not accepted where an int setting is expected."""
        self_tool = _make_tool()
        outcome = await self_tool.execute(action="modify", key="max_iterations", value=None)
        assert "Error" in outcome
        assert "must be int" in outcome

    @pytest.mark.asyncio
    async def test_inspect_all_truncates_large_runtime_vars(self):
        """Inspecting everything with many large runtime vars marks output as truncated."""
        mock_loop = _make_mock_loop()
        # 100 entries, each value a 100-fold repetition, to make the dump large.
        mock_loop._runtime_vars = {f"key_{i}": f"value_{i}" * 100 for i in range(100)}
        self_tool = _make_tool(mock_loop)

        outcome = await self_tool.execute(action="inspect")

        assert "truncated" in outcome

    @pytest.mark.asyncio
    async def test_reset_internal_attribute_returns_error(self):
        """Resetting ``_config_defaults`` is refused with a "protected" message."""
        self_tool = _make_tool()
        outcome = await self_tool.execute(action="reset", key="_config_defaults")
        assert "protected" in outcome
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# watchdog (rules simulated inline against a mock loop object)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWatchdog:
    """Checks of the watchdog rules against a mock loop.

    NOTE(review): each test re-implements the watchdog check inline instead of
    invoking production code, so these tests pin the intended rule only.
    """

    def test_watchdog_corrects_invalid_iterations(self):
        """An out-of-range ``max_iterations`` falls back to the configured default."""
        mock_loop = _make_mock_loop()
        mock_loop.max_iterations = 0
        mock_loop._critical_tool_backup = {}

        # Simulated watchdog: the accepted range is 1..100 inclusive.
        fallback = mock_loop._config_defaults
        if mock_loop.max_iterations < 1 or mock_loop.max_iterations > 100:
            mock_loop.max_iterations = fallback["max_iterations"]

        assert mock_loop.max_iterations == 40

    def test_watchdog_corrects_invalid_context_window(self):
        """An out-of-range ``context_window_tokens`` falls back to the default."""
        mock_loop = _make_mock_loop()
        mock_loop.context_window_tokens = 100
        mock_loop._critical_tool_backup = {}

        # Simulated watchdog: the accepted range is 4096..1_000_000 inclusive.
        fallback = mock_loop._config_defaults
        if mock_loop.context_window_tokens < 4096 or mock_loop.context_window_tokens > 1_000_000:
            mock_loop.context_window_tokens = fallback["context_window_tokens"]

        assert mock_loop.context_window_tokens == 65_536

    def test_watchdog_restores_critical_tools(self):
        """A missing critical tool is re-registered from a deep copy of its backup."""
        mock_loop = _make_mock_loop()
        saved_tool = MagicMock()
        mock_loop._critical_tool_backup = {"self": saved_tool}
        mock_loop.tools.has.return_value = False
        mock_loop.tools.tool_names = []

        # Simulated watchdog: re-register every backed-up tool that is absent.
        for tool_name, backup_copy in mock_loop._critical_tool_backup.items():
            if mock_loop.tools.has(tool_name):
                continue
            mock_loop.tools.register(copy.deepcopy(backup_copy))

        mock_loop.tools.register.assert_called()
        # The registry must receive a copy, never the stored backup itself.
        registered = mock_loop.tools.register.call_args.args[0]
        assert registered is not saved_tool  # deep copy

    def test_watchdog_does_not_reset_valid_state(self):
        """In-range settings are left untouched."""
        mock_loop = _make_mock_loop()
        mock_loop.max_iterations = 50
        mock_loop.context_window_tokens = 131072
        mock_loop._critical_tool_backup = {}
        iterations_before = mock_loop.max_iterations
        window_before = mock_loop.context_window_tokens

        # Simulated watchdog: both range checks, neither of which should fire.
        if mock_loop.max_iterations < 1 or mock_loop.max_iterations > 100:
            mock_loop.max_iterations = mock_loop._config_defaults["max_iterations"]
        if mock_loop.context_window_tokens < 4096 or mock_loop.context_window_tokens > 1_000_000:
            mock_loop.context_window_tokens = mock_loop._config_defaults["context_window_tokens"]

        assert mock_loop.max_iterations == iterations_before
        assert mock_loop.context_window_tokens == window_before
|
||||
Loading…
x
Reference in New Issue
Block a user