mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 09:22:36 +00:00
fix(self_tool): address code review issues — dead code, None ambiguity, size limit, watchdog tests
- Remove duplicate BLOCKED check in _modify (dead code) - Use hasattr() instead of None check to distinguish missing vs None-valued attributes - Add 64-key cap on _runtime_vars to prevent unbounded memory growth - Refactor watchdog tests to call actual _watchdog_check() instead of inline logic
This commit is contained in:
parent
3684e410b2
commit
c3f54088a6
@ -132,7 +132,6 @@ class AgentLoop:
|
||||
self._config_defaults: dict[str, Any] = {
|
||||
"max_iterations": max_iterations,
|
||||
"context_window_tokens": context_window_tokens,
|
||||
"context_budget_tokens": context_budget_tokens,
|
||||
"model": self.model,
|
||||
}
|
||||
self._runtime_vars: dict[str, Any] = {}
|
||||
@ -207,9 +206,6 @@ class AgentLoop:
|
||||
if not (4096 <= self.context_window_tokens <= 1_000_000):
|
||||
logger.warning("Watchdog: resetting context_window_tokens {} -> {}", self.context_window_tokens, defaults["context_window_tokens"])
|
||||
self.context_window_tokens = defaults["context_window_tokens"]
|
||||
if not (0 <= self.context_budget_tokens <= 1_000_000):
|
||||
logger.warning("Watchdog: resetting context_budget_tokens {} -> {}", self.context_budget_tokens, defaults["context_budget_tokens"])
|
||||
self.context_budget_tokens = defaults["context_budget_tokens"]
|
||||
# Restore critical tools if they were somehow removed
|
||||
for name, backup in self._critical_tool_backup.items():
|
||||
if not self.tools.has(name):
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from loguru import logger
|
||||
@ -47,10 +46,12 @@ class SelfTool(Tool):
|
||||
RESTRICTED: dict[str, dict[str, Any]] = {
|
||||
"max_iterations": {"type": int, "min": 1, "max": 100},
|
||||
"context_window_tokens": {"type": int, "min": 4096, "max": 1_000_000},
|
||||
"context_budget_tokens": {"type": int, "min": 0, "max": 1_000_000},
|
||||
"model": {"type": str, "min_len": 1},
|
||||
}
|
||||
|
||||
# Max number of elements (list items + dict entries summed recursively) per value
|
||||
_MAX_VALUE_ELEMENTS = 1024
|
||||
|
||||
def __init__(self, loop: AgentLoop) -> None:
|
||||
self._loop = loop
|
||||
self._channel = ""
|
||||
@ -157,10 +158,9 @@ class SelfTool(Tool):
|
||||
return f"_unregistered_tools: {list(self._loop._unregistered_tools.keys())}"
|
||||
if key in self.BLOCKED or key.startswith("__") or key in self._DENIED_ATTRS:
|
||||
return f"Error: '{key}' is not accessible"
|
||||
val = getattr(self._loop, key, None)
|
||||
if val is None:
|
||||
return "'{key}' not found on agent".format(key=key)
|
||||
return f"{key}: {val!r}"
|
||||
if not hasattr(self._loop, key):
|
||||
return f"'{key}' not found on agent"
|
||||
return f"{key}: {getattr(self._loop, key)!r}"
|
||||
|
||||
def _inspect_all(self) -> str:
|
||||
loop = self._loop
|
||||
@ -173,7 +173,8 @@ class SelfTool(Tool):
|
||||
val = getattr(loop, k, None)
|
||||
if val is not None:
|
||||
parts.append(f"{k}: {val!r}")
|
||||
# Tools
|
||||
# Tools (intentionally exposed here for operational visibility,
|
||||
# even though 'tools' is BLOCKED for single-key inspect/modify)
|
||||
parts.append(f"tools: {loop.tools.tool_names}")
|
||||
# Runtime vars (limit output size)
|
||||
rv = loop._runtime_vars
|
||||
@ -199,10 +200,7 @@ class SelfTool(Tool):
|
||||
def _modify(self, key: str | None, value: Any) -> str:
|
||||
if err := self._validate_key(key):
|
||||
return err
|
||||
if key in self.BLOCKED or key.startswith("__"):
|
||||
self._audit("modify", f"BLOCKED {key}")
|
||||
return f"Error: '{key}' is protected and cannot be modified"
|
||||
if key in self.BLOCKED:
|
||||
if key in self.BLOCKED or key.startswith("__") or key in self._DENIED_ATTRS:
|
||||
self._audit("modify", f"BLOCKED {key}")
|
||||
return f"Error: '{key}' is protected and cannot be modified"
|
||||
if key in self.READONLY:
|
||||
@ -248,28 +246,43 @@ class SelfTool(Tool):
|
||||
if err:
|
||||
self._audit("modify", f"REJECTED {key}: {err}")
|
||||
return f"Error: {err}"
|
||||
# Limit total keys to prevent unbounded memory growth
|
||||
if key not in self._loop._runtime_vars and len(self._loop._runtime_vars) >= 64:
|
||||
self._audit("modify", f"REJECTED {key}: max keys (64) reached")
|
||||
return "Error: _runtime_vars is full (max 64 keys). Reset unused keys first."
|
||||
old = self._loop._runtime_vars.get(key)
|
||||
self._loop._runtime_vars[key] = value
|
||||
self._audit("modify", f"_runtime_vars.{key}: {old!r} -> {value!r}")
|
||||
return f"Set _runtime_vars.{key} = {value!r}"
|
||||
|
||||
@staticmethod
|
||||
def _validate_json_safe(value: Any, depth: int = 0) -> str | None:
|
||||
"""Validate that value is JSON-safe (no nested references to live objects)."""
|
||||
@classmethod
|
||||
def _validate_json_safe(cls, value: Any, depth: int = 0, elements: int = 0) -> str | None:
|
||||
"""Validate that value is JSON-safe (no nested references to live objects).
|
||||
|
||||
Returns an error string if validation fails, or None if the value is safe.
|
||||
``elements`` tracks the cumulative count of list items + dict entries to
|
||||
enforce a per-value size cap.
|
||||
"""
|
||||
if depth > 10:
|
||||
return "value nesting too deep (max 10 levels)"
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
return None
|
||||
if isinstance(value, list):
|
||||
elements += len(value)
|
||||
if elements > cls._MAX_VALUE_ELEMENTS:
|
||||
return f"value too large (max {cls._MAX_VALUE_ELEMENTS} total elements)"
|
||||
for i, item in enumerate(value):
|
||||
if err := SelfTool._validate_json_safe(item, depth + 1):
|
||||
if err := cls._validate_json_safe(item, depth + 1, elements):
|
||||
return f"list[{i}] contains {err}"
|
||||
return None
|
||||
if isinstance(value, dict):
|
||||
elements += len(value)
|
||||
if elements > cls._MAX_VALUE_ELEMENTS:
|
||||
return f"value too large (max {cls._MAX_VALUE_ELEMENTS} total elements)"
|
||||
for k, v in value.items():
|
||||
if not isinstance(k, str):
|
||||
return f"dict key must be str, got {type(k).__name__}"
|
||||
if err := SelfTool._validate_json_safe(v, depth + 1):
|
||||
if err := cls._validate_json_safe(v, depth + 1, elements):
|
||||
return f"dict key '{k}' contains {err}"
|
||||
return None
|
||||
return f"unsupported type {type(value).__name__}"
|
||||
@ -313,9 +326,9 @@ class SelfTool(Tool):
|
||||
if key in self.READONLY:
|
||||
return f"Error: '{key}' is read-only (already at its configured value)"
|
||||
if key in self.RESTRICTED:
|
||||
default = self._loop._config_defaults.get(key)
|
||||
if default is None:
|
||||
if key not in self._loop._config_defaults:
|
||||
return f"Error: no config default for '{key}'"
|
||||
default = self._loop._config_defaults[key]
|
||||
old = getattr(self._loop, key)
|
||||
setattr(self._loop, key, default)
|
||||
self._audit("reset", f"{key}: {old!r} -> {default!r}")
|
||||
|
||||
@ -2,12 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.tools.self import SelfTool
|
||||
|
||||
|
||||
@ -47,6 +47,9 @@ def _make_mock_loop(**overrides):
|
||||
loop.tools.has.side_effect = lambda n: n in loop.tools.tool_names
|
||||
loop.tools.get.return_value = None
|
||||
|
||||
# Attach the real _watchdog_check method to the mock so tests exercise actual code
|
||||
loop._watchdog_check = AgentLoop._watchdog_check.__get__(loop)
|
||||
|
||||
for k, v in overrides.items():
|
||||
setattr(loop, k, v)
|
||||
|
||||
@ -94,13 +97,30 @@ class TestInspect:
|
||||
assert "task" in result
|
||||
assert "review" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inspect_none_attribute_shows_value(self):
|
||||
"""Attributes that are legitimately None should show their value, not 'not found'."""
|
||||
tool = _make_tool()
|
||||
# web_proxy is initialized as None in the mock
|
||||
result = await tool.execute(action="inspect", key="web_proxy")
|
||||
assert "web_proxy: None" in result
|
||||
assert "not found" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inspect_unknown_key(self):
|
||||
loop = _make_mock_loop()
|
||||
# Make getattr return None for unknown keys (like a real object)
|
||||
loop.nonexistent = None
|
||||
tool = _make_tool(loop)
|
||||
result = await tool.execute(action="inspect", key="nonexistent")
|
||||
"""Use a real object (not MagicMock) so hasattr returns False for missing attrs."""
|
||||
class MinimalLoop:
|
||||
model = "test-model"
|
||||
max_iterations = 40
|
||||
context_window_tokens = 65536
|
||||
context_budget_tokens = 500
|
||||
_runtime_vars = {}
|
||||
_unregistered_tools = {}
|
||||
_config_defaults = {}
|
||||
_critical_tool_backup = {}
|
||||
loop = MinimalLoop()
|
||||
tool = SelfTool(loop=loop)
|
||||
result = await tool.execute(action="inspect", key="nonexistent_key")
|
||||
assert "not found" in result
|
||||
|
||||
|
||||
@ -248,6 +268,24 @@ class TestModifyFree:
|
||||
result = await tool.execute(action="modify", key="data", value={"a": 1})
|
||||
assert tool._loop._runtime_vars["data"] == {"a": 1}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modify_free_rejects_when_max_keys_reached(self):
|
||||
loop = _make_mock_loop()
|
||||
loop._runtime_vars = {f"key_{i}": i for i in range(64)}
|
||||
tool = _make_tool(loop)
|
||||
result = await tool.execute(action="modify", key="overflow", value="data")
|
||||
assert "full" in result
|
||||
assert "overflow" not in loop._runtime_vars
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modify_free_allows_update_existing_key_at_max(self):
|
||||
loop = _make_mock_loop()
|
||||
loop._runtime_vars = {f"key_{i}": i for i in range(64)}
|
||||
tool = _make_tool(loop)
|
||||
result = await tool.execute(action="modify", key="key_0", value="updated")
|
||||
assert "Error" not in result
|
||||
assert loop._runtime_vars["key_0"] == "updated"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# unregister_tool / register_tool
|
||||
@ -442,6 +480,33 @@ class TestEdgeCases:
|
||||
result = await tool.execute(action="reset", key="_config_defaults")
|
||||
assert "protected" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modify_denied_attrs_non_dunder_blocked(self):
|
||||
"""Non-dunder entries in _DENIED_ATTRS (e.g. func_globals) must be blocked."""
|
||||
tool = _make_tool()
|
||||
for attr in ("func_globals", "func_code"):
|
||||
result = await tool.execute(action="modify", key=attr, value="evil")
|
||||
assert "protected" in result, f"{attr} should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modify_free_value_too_large_rejected(self):
|
||||
"""Values exceeding _MAX_VALUE_ELEMENTS should be rejected."""
|
||||
tool = _make_tool()
|
||||
big_list = list(range(2000))
|
||||
result = await tool.execute(action="modify", key="big", value=big_list)
|
||||
assert "too large" in result
|
||||
assert "big" not in tool._loop._runtime_vars
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_with_none_default_succeeds(self):
|
||||
"""Reset should work even if the config default is legitimately None."""
|
||||
loop = _make_mock_loop()
|
||||
loop._config_defaults["max_iterations"] = None
|
||||
loop.max_iterations = 80
|
||||
tool = _make_tool(loop)
|
||||
result = await tool.execute(action="reset", key="max_iterations")
|
||||
assert "Reset max_iterations = None" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# watchdog (tested via AgentLoop method, using a real loop-like object)
|
||||
@ -452,20 +517,13 @@ class TestWatchdog:
|
||||
def test_watchdog_corrects_invalid_iterations(self):
|
||||
loop = _make_mock_loop()
|
||||
loop.max_iterations = 0
|
||||
loop._critical_tool_backup = {}
|
||||
# Simulate watchdog
|
||||
defaults = loop._config_defaults
|
||||
if not (1 <= loop.max_iterations <= 100):
|
||||
loop.max_iterations = defaults["max_iterations"]
|
||||
loop._watchdog_check()
|
||||
assert loop.max_iterations == 40
|
||||
|
||||
def test_watchdog_corrects_invalid_context_window(self):
|
||||
loop = _make_mock_loop()
|
||||
loop.context_window_tokens = 100
|
||||
loop._critical_tool_backup = {}
|
||||
defaults = loop._config_defaults
|
||||
if not (4096 <= loop.context_window_tokens <= 1_000_000):
|
||||
loop.context_window_tokens = defaults["context_window_tokens"]
|
||||
loop._watchdog_check()
|
||||
assert loop.context_window_tokens == 65_536
|
||||
|
||||
def test_watchdog_restores_critical_tools(self):
|
||||
@ -474,10 +532,7 @@ class TestWatchdog:
|
||||
loop._critical_tool_backup = {"self": backup}
|
||||
loop.tools.has.return_value = False
|
||||
loop.tools.tool_names = []
|
||||
# Simulate watchdog
|
||||
for name, bk in loop._critical_tool_backup.items():
|
||||
if not loop.tools.has(name):
|
||||
loop.tools.register(copy.deepcopy(bk))
|
||||
loop._watchdog_check()
|
||||
loop.tools.register.assert_called()
|
||||
# Verify it was called with a copy, not the original
|
||||
called_arg = loop.tools.register.call_args[0][0]
|
||||
@ -487,13 +542,8 @@ class TestWatchdog:
|
||||
loop = _make_mock_loop()
|
||||
loop.max_iterations = 50
|
||||
loop.context_window_tokens = 131072
|
||||
loop._critical_tool_backup = {}
|
||||
original_max = loop.max_iterations
|
||||
original_ctx = loop.context_window_tokens
|
||||
# Simulate watchdog
|
||||
if not (1 <= loop.max_iterations <= 100):
|
||||
loop.max_iterations = loop._config_defaults["max_iterations"]
|
||||
if not (4096 <= loop.context_window_tokens <= 1_000_000):
|
||||
loop.context_window_tokens = loop._config_defaults["context_window_tokens"]
|
||||
loop._watchdog_check()
|
||||
assert loop.max_iterations == original_max
|
||||
assert loop.context_window_tokens == original_ctx
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user