fix(tools): isolate decorated tool schemas and add regression tests

This commit is contained in:
Xubin Ren 2026-04-04 11:53:42 +00:00 committed by Xubin Ren
parent e7798a28ee
commit 05fe7d4fb1
2 changed files with 51 additions and 4 deletions

View File

@ -2,6 +2,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from copy import deepcopy
from typing import Any, TypeVar from typing import Any, TypeVar
_ToolT = TypeVar("_ToolT", bound="Tool") _ToolT = TypeVar("_ToolT", bound="Tool")
@ -246,7 +247,7 @@ def tool_parameters(schema: dict[str, Any]) -> Callable[[type[_ToolT]], type[_To
"""Class decorator: attach JSON Schema and inject a concrete ``parameters`` property. """Class decorator: attach JSON Schema and inject a concrete ``parameters`` property.
Use on ``Tool`` subclasses instead of writing ``@property def parameters``. The Use on ``Tool`` subclasses instead of writing ``@property def parameters``. The
schema is stored on the class (shallow-copied) as ``_tool_parameters_schema``. schema is stored on the class and returned as a fresh copy on each access.
Example:: Example::
@ -260,13 +261,13 @@ def tool_parameters(schema: dict[str, Any]) -> Callable[[type[_ToolT]], type[_To
""" """
def decorator(cls: type[_ToolT]) -> type[_ToolT]: def decorator(cls: type[_ToolT]) -> type[_ToolT]:
frozen = dict(schema) frozen = deepcopy(schema)
@property @property
def parameters(self: Any) -> dict[str, Any]: def parameters(self: Any) -> dict[str, Any]:
return frozen return deepcopy(frozen)
cls._tool_parameters_schema = frozen cls._tool_parameters_schema = deepcopy(frozen)
cls.parameters = parameters # type: ignore[assignment] cls.parameters = parameters # type: ignore[assignment]
abstract = getattr(cls, "__abstractmethods__", None) abstract = getattr(cls, "__abstractmethods__", None)

View File

@ -6,6 +6,7 @@ from nanobot.agent.tools import (
ObjectSchema, ObjectSchema,
Schema, Schema,
StringSchema, StringSchema,
tool_parameters,
tool_parameters_schema, tool_parameters_schema,
) )
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
@ -49,6 +50,26 @@ class SampleTool(Tool):
return "ok" return "ok"
@tool_parameters(
tool_parameters_schema(
query=StringSchema(min_length=2),
count=IntegerSchema(2, minimum=1, maximum=10),
required=["query", "count"],
)
)
class DecoratedSampleTool(Tool):
@property
def name(self) -> str:
return "decorated_sample"
@property
def description(self) -> str:
return "decorated sample tool"
async def execute(self, **kwargs: Any) -> str:
return f"ok:{kwargs['count']}"
def test_schema_validate_value_matches_tool_validate_params() -> None: def test_schema_validate_value_matches_tool_validate_params() -> None:
"""ObjectSchema.validate_value 与 validate_json_schema_value、Tool.validate_params 一致。""" """ObjectSchema.validate_value 与 validate_json_schema_value、Tool.validate_params 一致。"""
root = tool_parameters_schema( root = tool_parameters_schema(
@ -101,6 +122,31 @@ def test_schema_classes_equivalent_to_sample_tool_parameters() -> None:
assert built == SampleTool().parameters assert built == SampleTool().parameters
def test_tool_parameters_returns_fresh_copy_per_access() -> None:
tool = DecoratedSampleTool()
first = tool.parameters
second = tool.parameters
assert first == second
assert first is not second
assert first["properties"] is not second["properties"]
first["properties"]["query"]["minLength"] = 99
assert tool.parameters["properties"]["query"]["minLength"] == 2
async def test_registry_executes_decorated_tool_end_to_end() -> None:
reg = ToolRegistry()
reg.register(DecoratedSampleTool())
ok = await reg.execute("decorated_sample", {"query": "hello", "count": "3"})
assert ok == "ok:3"
err = await reg.execute("decorated_sample", {"query": "h", "count": 3})
assert "Invalid parameters" in err
def test_validate_params_missing_required() -> None: def test_validate_params_missing_required() -> None:
tool = SampleTool() tool = SampleTool()
errors = tool.validate_params({"query": "hi"}) errors = tool.validate_params({"query": "hi"})