diff --git a/nanobot/agent/tools/base.py b/nanobot/agent/tools/base.py index 5e19e5c40..9e63620dd 100644 --- a/nanobot/agent/tools/base.py +++ b/nanobot/agent/tools/base.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable +from copy import deepcopy from typing import Any, TypeVar _ToolT = TypeVar("_ToolT", bound="Tool") @@ -246,7 +247,7 @@ def tool_parameters(schema: dict[str, Any]) -> Callable[[type[_ToolT]], type[_To """Class decorator: attach JSON Schema and inject a concrete ``parameters`` property. Use on ``Tool`` subclasses instead of writing ``@property def parameters``. The - schema is stored on the class (shallow-copied) as ``_tool_parameters_schema``. + schema is stored on the class and returned as a fresh copy on each access. Example:: @@ -260,13 +261,13 @@ def tool_parameters(schema: dict[str, Any]) -> Callable[[type[_ToolT]], type[_To """ def decorator(cls: type[_ToolT]) -> type[_ToolT]: - frozen = dict(schema) + frozen = deepcopy(schema) @property def parameters(self: Any) -> dict[str, Any]: - return frozen + return deepcopy(frozen) - cls._tool_parameters_schema = frozen + cls._tool_parameters_schema = deepcopy(frozen) cls.parameters = parameters # type: ignore[assignment] abstract = getattr(cls, "__abstractmethods__", None) diff --git a/tests/tools/test_tool_validation.py b/tests/tools/test_tool_validation.py index b1d56a439..e56f93185 100644 --- a/tests/tools/test_tool_validation.py +++ b/tests/tools/test_tool_validation.py @@ -6,6 +6,7 @@ from nanobot.agent.tools import ( ObjectSchema, Schema, StringSchema, + tool_parameters, tool_parameters_schema, ) from nanobot.agent.tools.base import Tool @@ -49,6 +50,26 @@ class SampleTool(Tool): return "ok" +@tool_parameters( + tool_parameters_schema( + query=StringSchema(min_length=2), + count=IntegerSchema(2, minimum=1, maximum=10), + required=["query", "count"], + ) +) +class DecoratedSampleTool(Tool): + @property + def name(self) -> str: + return "decorated_sample" + + @property + def description(self) -> str: + return "decorated sample tool" + + async def execute(self, **kwargs: Any) -> str: + return f"ok:{kwargs['count']}" + + def test_schema_validate_value_matches_tool_validate_params() -> None: """ObjectSchema.validate_value 与 validate_json_schema_value、Tool.validate_params 一致。""" root = tool_parameters_schema( @@ -101,6 +122,31 @@ def test_schema_classes_equivalent_to_sample_tool_parameters() -> None: assert built == SampleTool().parameters +def test_tool_parameters_returns_fresh_copy_per_access() -> None: + tool = DecoratedSampleTool() + + first = tool.parameters + second = tool.parameters + + assert first == second + assert first is not second + assert first["properties"] is not second["properties"] + + first["properties"]["query"]["minLength"] = 99 + assert tool.parameters["properties"]["query"]["minLength"] == 2 + + +async def test_registry_executes_decorated_tool_end_to_end() -> None: + reg = ToolRegistry() + reg.register(DecoratedSampleTool()) + + ok = await reg.execute("decorated_sample", {"query": "hello", "count": "3"}) + assert ok == "ok:3" + + err = await reg.execute("decorated_sample", {"query": "h", "count": 3}) + assert "Invalid parameters" in err + + def test_validate_params_missing_required() -> None: tool = SampleTool() errors = tool.validate_params({"query": "hi"})