fix(self_tool): address code review issues — dead code, None ambiguity, size limit, watchdog tests

- Remove duplicate BLOCKED check in _modify (dead code)
- Use hasattr() instead of None check to distinguish missing vs None-valued attributes
- Add 64-key cap on _runtime_vars to prevent unbounded memory growth
- Refactor watchdog tests to call actual _watchdog_check() instead of inline logic
This commit is contained in:
chengyongru 2026-03-28 00:02:10 +08:00 committed by chengyongru
parent 3684e410b2
commit c3f54088a6
3 changed files with 106 additions and 47 deletions

View File

@ -132,7 +132,6 @@ class AgentLoop:
self._config_defaults: dict[str, Any] = {
"max_iterations": max_iterations,
"context_window_tokens": context_window_tokens,
"context_budget_tokens": context_budget_tokens,
"model": self.model,
}
self._runtime_vars: dict[str, Any] = {}
@ -207,9 +206,6 @@ class AgentLoop:
if not (4096 <= self.context_window_tokens <= 1_000_000):
logger.warning("Watchdog: resetting context_window_tokens {} -> {}", self.context_window_tokens, defaults["context_window_tokens"])
self.context_window_tokens = defaults["context_window_tokens"]
if not (0 <= self.context_budget_tokens <= 1_000_000):
logger.warning("Watchdog: resetting context_budget_tokens {} -> {}", self.context_budget_tokens, defaults["context_budget_tokens"])
self.context_budget_tokens = defaults["context_budget_tokens"]
# Restore critical tools if they were somehow removed
for name, backup in self._critical_tool_backup.items():
if not self.tools.has(name):

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import copy
from typing import TYPE_CHECKING, Any
from loguru import logger
@ -47,10 +46,12 @@ class SelfTool(Tool):
RESTRICTED: dict[str, dict[str, Any]] = {
"max_iterations": {"type": int, "min": 1, "max": 100},
"context_window_tokens": {"type": int, "min": 4096, "max": 1_000_000},
"context_budget_tokens": {"type": int, "min": 0, "max": 1_000_000},
"model": {"type": str, "min_len": 1},
}
# Max number of elements (list items + dict entries summed recursively) per value
_MAX_VALUE_ELEMENTS = 1024
def __init__(self, loop: AgentLoop) -> None:
self._loop = loop
self._channel = ""
@ -157,10 +158,9 @@ class SelfTool(Tool):
return f"_unregistered_tools: {list(self._loop._unregistered_tools.keys())}"
if key in self.BLOCKED or key.startswith("__") or key in self._DENIED_ATTRS:
return f"Error: '{key}' is not accessible"
val = getattr(self._loop, key, None)
if val is None:
return "'{key}' not found on agent".format(key=key)
return f"{key}: {val!r}"
if not hasattr(self._loop, key):
return f"'{key}' not found on agent"
return f"{key}: {getattr(self._loop, key)!r}"
def _inspect_all(self) -> str:
loop = self._loop
@ -173,7 +173,8 @@ class SelfTool(Tool):
val = getattr(loop, k, None)
if val is not None:
parts.append(f"{k}: {val!r}")
# Tools
# Tools (intentionally exposed here for operational visibility,
# even though 'tools' is BLOCKED for single-key inspect/modify)
parts.append(f"tools: {loop.tools.tool_names}")
# Runtime vars (limit output size)
rv = loop._runtime_vars
@ -199,10 +200,7 @@ class SelfTool(Tool):
def _modify(self, key: str | None, value: Any) -> str:
if err := self._validate_key(key):
return err
if key in self.BLOCKED or key.startswith("__"):
self._audit("modify", f"BLOCKED {key}")
return f"Error: '{key}' is protected and cannot be modified"
if key in self.BLOCKED:
if key in self.BLOCKED or key.startswith("__") or key in self._DENIED_ATTRS:
self._audit("modify", f"BLOCKED {key}")
return f"Error: '{key}' is protected and cannot be modified"
if key in self.READONLY:
@ -248,28 +246,43 @@ class SelfTool(Tool):
if err:
self._audit("modify", f"REJECTED {key}: {err}")
return f"Error: {err}"
# Limit total keys to prevent unbounded memory growth
if key not in self._loop._runtime_vars and len(self._loop._runtime_vars) >= 64:
self._audit("modify", f"REJECTED {key}: max keys (64) reached")
return "Error: _runtime_vars is full (max 64 keys). Reset unused keys first."
old = self._loop._runtime_vars.get(key)
self._loop._runtime_vars[key] = value
self._audit("modify", f"_runtime_vars.{key}: {old!r} -> {value!r}")
return f"Set _runtime_vars.{key} = {value!r}"
@staticmethod
def _validate_json_safe(value: Any, depth: int = 0) -> str | None:
"""Validate that value is JSON-safe (no nested references to live objects)."""
@classmethod
def _validate_json_safe(cls, value: Any, depth: int = 0, elements: int = 0) -> str | None:
"""Validate that value is JSON-safe (no nested references to live objects).
Returns an error string if validation fails, or None if the value is safe.
``elements`` tracks the cumulative count of list items + dict entries to
enforce a per-value size cap.
"""
if depth > 10:
return "value nesting too deep (max 10 levels)"
if isinstance(value, (str, int, float, bool, type(None))):
return None
if isinstance(value, list):
elements += len(value)
if elements > cls._MAX_VALUE_ELEMENTS:
return f"value too large (max {cls._MAX_VALUE_ELEMENTS} total elements)"
for i, item in enumerate(value):
if err := SelfTool._validate_json_safe(item, depth + 1):
if err := cls._validate_json_safe(item, depth + 1, elements):
return f"list[{i}] contains {err}"
return None
if isinstance(value, dict):
elements += len(value)
if elements > cls._MAX_VALUE_ELEMENTS:
return f"value too large (max {cls._MAX_VALUE_ELEMENTS} total elements)"
for k, v in value.items():
if not isinstance(k, str):
return f"dict key must be str, got {type(k).__name__}"
if err := SelfTool._validate_json_safe(v, depth + 1):
if err := cls._validate_json_safe(v, depth + 1, elements):
return f"dict key '{k}' contains {err}"
return None
return f"unsupported type {type(value).__name__}"
@ -313,9 +326,9 @@ class SelfTool(Tool):
if key in self.READONLY:
return f"Error: '{key}' is read-only (already at its configured value)"
if key in self.RESTRICTED:
default = self._loop._config_defaults.get(key)
if default is None:
if key not in self._loop._config_defaults:
return f"Error: no config default for '{key}'"
default = self._loop._config_defaults[key]
old = getattr(self._loop, key)
setattr(self._loop, key, default)
self._audit("reset", f"{key}: {old!r} -> {default!r}")

View File

@ -2,12 +2,12 @@
from __future__ import annotations
import copy
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.agent.tools.self import SelfTool
@ -47,6 +47,9 @@ def _make_mock_loop(**overrides):
loop.tools.has.side_effect = lambda n: n in loop.tools.tool_names
loop.tools.get.return_value = None
# Attach the real _watchdog_check method to the mock so tests exercise actual code
loop._watchdog_check = AgentLoop._watchdog_check.__get__(loop)
for k, v in overrides.items():
setattr(loop, k, v)
@ -94,13 +97,30 @@ class TestInspect:
assert "task" in result
assert "review" in result
@pytest.mark.asyncio
async def test_inspect_none_attribute_shows_value(self):
    """A None-valued attribute must be reported as None, not as missing."""
    tool = _make_tool()
    # The mock loop initializes web_proxy to None, so hasattr() is True
    # and inspect should surface the value rather than "not found".
    out = await tool.execute(action="inspect", key="web_proxy")
    assert "web_proxy: None" in out
    assert "not found" not in out
@pytest.mark.asyncio
async def test_inspect_unknown_key(self):
loop = _make_mock_loop()
# Make getattr return None for unknown keys (like a real object)
loop.nonexistent = None
tool = _make_tool(loop)
result = await tool.execute(action="inspect", key="nonexistent")
"""Use a real object (not MagicMock) so hasattr returns False for missing attrs."""
class MinimalLoop:
model = "test-model"
max_iterations = 40
context_window_tokens = 65536
context_budget_tokens = 500
_runtime_vars = {}
_unregistered_tools = {}
_config_defaults = {}
_critical_tool_backup = {}
loop = MinimalLoop()
tool = SelfTool(loop=loop)
result = await tool.execute(action="inspect", key="nonexistent_key")
assert "not found" in result
@ -248,6 +268,24 @@ class TestModifyFree:
result = await tool.execute(action="modify", key="data", value={"a": 1})
assert tool._loop._runtime_vars["data"] == {"a": 1}
@pytest.mark.asyncio
async def test_modify_free_rejects_when_max_keys_reached(self):
    """A brand-new key is refused once _runtime_vars already holds 64 entries."""
    loop = _make_mock_loop()
    # Pre-fill the store right up to the cap.
    loop._runtime_vars = {}
    for idx in range(64):
        loop._runtime_vars[f"key_{idx}"] = idx
    tool = _make_tool(loop)
    response = await tool.execute(action="modify", key="overflow", value="data")
    assert "full" in response
    assert "overflow" not in loop._runtime_vars
@pytest.mark.asyncio
async def test_modify_free_allows_update_existing_key_at_max(self):
    """Overwriting an existing key still succeeds when the 64-key cap is reached."""
    loop = _make_mock_loop()
    # Exactly at the cap: updates to existing keys must not be rejected.
    loop._runtime_vars = dict(zip((f"key_{i}" for i in range(64)), range(64)))
    tool = _make_tool(loop)
    outcome = await tool.execute(action="modify", key="key_0", value="updated")
    assert "Error" not in outcome
    assert loop._runtime_vars["key_0"] == "updated"
# ---------------------------------------------------------------------------
# unregister_tool / register_tool
@ -442,6 +480,33 @@ class TestEdgeCases:
result = await tool.execute(action="reset", key="_config_defaults")
assert "protected" in result
@pytest.mark.asyncio
async def test_modify_denied_attrs_non_dunder_blocked(self):
    """Non-dunder entries in _DENIED_ATTRS (e.g. func_globals) must be blocked."""
    tool = _make_tool()
    denied = ("func_globals", "func_code")
    for attr in denied:
        reply = await tool.execute(action="modify", key=attr, value="evil")
        assert "protected" in reply, f"{attr} should be blocked"
@pytest.mark.asyncio
async def test_modify_free_value_too_large_rejected(self):
    """Values exceeding _MAX_VALUE_ELEMENTS should be rejected."""
    tool = _make_tool()
    # 2000 elements is well past the 1024-element cap.
    oversized = [*range(2000)]
    verdict = await tool.execute(action="modify", key="big", value=oversized)
    assert "too large" in verdict
    assert "big" not in tool._loop._runtime_vars
@pytest.mark.asyncio
async def test_reset_with_none_default_succeeds(self):
    """Reset should work even if the config default is legitimately None."""
    loop = _make_mock_loop()
    # Divergent live value plus a None default: reset must still apply it.
    loop.max_iterations = 80
    loop._config_defaults["max_iterations"] = None
    tool = _make_tool(loop)
    msg = await tool.execute(action="reset", key="max_iterations")
    assert "Reset max_iterations = None" in msg
# ---------------------------------------------------------------------------
# watchdog (tested via AgentLoop method, using a real loop-like object)
@ -452,20 +517,13 @@ class TestWatchdog:
def test_watchdog_corrects_invalid_iterations(self):
loop = _make_mock_loop()
loop.max_iterations = 0
loop._critical_tool_backup = {}
# Simulate watchdog
defaults = loop._config_defaults
if not (1 <= loop.max_iterations <= 100):
loop.max_iterations = defaults["max_iterations"]
loop._watchdog_check()
assert loop.max_iterations == 40
def test_watchdog_corrects_invalid_context_window(self):
loop = _make_mock_loop()
loop.context_window_tokens = 100
loop._critical_tool_backup = {}
defaults = loop._config_defaults
if not (4096 <= loop.context_window_tokens <= 1_000_000):
loop.context_window_tokens = defaults["context_window_tokens"]
loop._watchdog_check()
assert loop.context_window_tokens == 65_536
def test_watchdog_restores_critical_tools(self):
@ -474,10 +532,7 @@ class TestWatchdog:
loop._critical_tool_backup = {"self": backup}
loop.tools.has.return_value = False
loop.tools.tool_names = []
# Simulate watchdog
for name, bk in loop._critical_tool_backup.items():
if not loop.tools.has(name):
loop.tools.register(copy.deepcopy(bk))
loop._watchdog_check()
loop.tools.register.assert_called()
# Verify it was called with a copy, not the original
called_arg = loop.tools.register.call_args[0][0]
@ -487,13 +542,8 @@ class TestWatchdog:
loop = _make_mock_loop()
loop.max_iterations = 50
loop.context_window_tokens = 131072
loop._critical_tool_backup = {}
original_max = loop.max_iterations
original_ctx = loop.context_window_tokens
# Simulate watchdog
if not (1 <= loop.max_iterations <= 100):
loop.max_iterations = loop._config_defaults["max_iterations"]
if not (4096 <= loop.context_window_tokens <= 1_000_000):
loop.context_window_tokens = loop._config_defaults["context_window_tokens"]
loop._watchdog_check()
assert loop.max_iterations == original_max
assert loop.context_window_tokens == original_ctx