fix(tools): isolate decorated tool schemas and add regression tests

This commit is contained in:
Xubin Ren 2026-04-04 11:53:42 +00:00 committed by Xubin Ren
parent e7798a28ee
commit 05fe7d4fb1
2 changed files with 51 additions and 4 deletions

View File

@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import deepcopy
from typing import Any, TypeVar
_ToolT = TypeVar("_ToolT", bound="Tool")
@ -246,7 +247,7 @@ def tool_parameters(schema: dict[str, Any]) -> Callable[[type[_ToolT]], type[_To
"""Class decorator: attach JSON Schema and inject a concrete ``parameters`` property.
Use on ``Tool`` subclasses instead of writing ``@property def parameters``. The
schema is stored on the class (shallow-copied) as ``_tool_parameters_schema``.
schema is stored on the class and returned as a fresh copy on each access.
Example::
@ -260,13 +261,13 @@ def tool_parameters(schema: dict[str, Any]) -> Callable[[type[_ToolT]], type[_To
"""
def decorator(cls: type[_ToolT]) -> type[_ToolT]:
frozen = dict(schema)
frozen = deepcopy(schema)
@property
def parameters(self: Any) -> dict[str, Any]:
return frozen
return deepcopy(frozen)
cls._tool_parameters_schema = frozen
cls._tool_parameters_schema = deepcopy(frozen)
cls.parameters = parameters # type: ignore[assignment]
abstract = getattr(cls, "__abstractmethods__", None)

View File

@ -6,6 +6,7 @@ from nanobot.agent.tools import (
ObjectSchema,
Schema,
StringSchema,
tool_parameters,
tool_parameters_schema,
)
from nanobot.agent.tools.base import Tool
@ -49,6 +50,26 @@ class SampleTool(Tool):
return "ok"
@tool_parameters(
tool_parameters_schema(
query=StringSchema(min_length=2),
count=IntegerSchema(2, minimum=1, maximum=10),
required=["query", "count"],
)
)
class DecoratedSampleTool(Tool):
@property
def name(self) -> str:
return "decorated_sample"
@property
def description(self) -> str:
return "decorated sample tool"
async def execute(self, **kwargs: Any) -> str:
return f"ok:{kwargs['count']}"
def test_schema_validate_value_matches_tool_validate_params() -> None:
"""ObjectSchema.validate_value 与 validate_json_schema_value、Tool.validate_params 一致。"""
root = tool_parameters_schema(
@ -101,6 +122,31 @@ def test_schema_classes_equivalent_to_sample_tool_parameters() -> None:
assert built == SampleTool().parameters
def test_tool_parameters_returns_fresh_copy_per_access() -> None:
tool = DecoratedSampleTool()
first = tool.parameters
second = tool.parameters
assert first == second
assert first is not second
assert first["properties"] is not second["properties"]
first["properties"]["query"]["minLength"] = 99
assert tool.parameters["properties"]["query"]["minLength"] == 2
async def test_registry_executes_decorated_tool_end_to_end() -> None:
reg = ToolRegistry()
reg.register(DecoratedSampleTool())
ok = await reg.execute("decorated_sample", {"query": "hello", "count": "3"})
assert ok == "ok:3"
err = await reg.execute("decorated_sample", {"query": "h", "count": 3})
assert "Invalid parameters" in err
def test_validate_params_missing_required() -> None:
tool = SampleTool()
errors = tool.validate_params({"query": "hi"})