mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-01 07:15:52 +00:00
fix(tools): isolate decorated tool schemas and add regression tests
This commit is contained in:
parent
e7798a28ee
commit
05fe7d4fb1
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from copy import deepcopy
|
||||||
from typing import Any, TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
_ToolT = TypeVar("_ToolT", bound="Tool")
|
_ToolT = TypeVar("_ToolT", bound="Tool")
|
||||||
@ -246,7 +247,7 @@ def tool_parameters(schema: dict[str, Any]) -> Callable[[type[_ToolT]], type[_To
|
|||||||
"""Class decorator: attach JSON Schema and inject a concrete ``parameters`` property.
|
"""Class decorator: attach JSON Schema and inject a concrete ``parameters`` property.
|
||||||
|
|
||||||
Use on ``Tool`` subclasses instead of writing ``@property def parameters``. The
|
Use on ``Tool`` subclasses instead of writing ``@property def parameters``. The
|
||||||
schema is stored on the class (shallow-copied) as ``_tool_parameters_schema``.
|
schema is stored on the class and returned as a fresh copy on each access.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
@ -260,13 +261,13 @@ def tool_parameters(schema: dict[str, Any]) -> Callable[[type[_ToolT]], type[_To
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(cls: type[_ToolT]) -> type[_ToolT]:
|
def decorator(cls: type[_ToolT]) -> type[_ToolT]:
|
||||||
frozen = dict(schema)
|
frozen = deepcopy(schema)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self: Any) -> dict[str, Any]:
|
def parameters(self: Any) -> dict[str, Any]:
|
||||||
return frozen
|
return deepcopy(frozen)
|
||||||
|
|
||||||
cls._tool_parameters_schema = frozen
|
cls._tool_parameters_schema = deepcopy(frozen)
|
||||||
cls.parameters = parameters # type: ignore[assignment]
|
cls.parameters = parameters # type: ignore[assignment]
|
||||||
|
|
||||||
abstract = getattr(cls, "__abstractmethods__", None)
|
abstract = getattr(cls, "__abstractmethods__", None)
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from nanobot.agent.tools import (
|
|||||||
ObjectSchema,
|
ObjectSchema,
|
||||||
Schema,
|
Schema,
|
||||||
StringSchema,
|
StringSchema,
|
||||||
|
tool_parameters,
|
||||||
tool_parameters_schema,
|
tool_parameters_schema,
|
||||||
)
|
)
|
||||||
from nanobot.agent.tools.base import Tool
|
from nanobot.agent.tools.base import Tool
|
||||||
@ -49,6 +50,26 @@ class SampleTool(Tool):
|
|||||||
return "ok"
|
return "ok"
|
||||||
|
|
||||||
|
|
||||||
|
@tool_parameters(
|
||||||
|
tool_parameters_schema(
|
||||||
|
query=StringSchema(min_length=2),
|
||||||
|
count=IntegerSchema(2, minimum=1, maximum=10),
|
||||||
|
required=["query", "count"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
class DecoratedSampleTool(Tool):
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "decorated_sample"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "decorated sample tool"
|
||||||
|
|
||||||
|
async def execute(self, **kwargs: Any) -> str:
|
||||||
|
return f"ok:{kwargs['count']}"
|
||||||
|
|
||||||
|
|
||||||
def test_schema_validate_value_matches_tool_validate_params() -> None:
|
def test_schema_validate_value_matches_tool_validate_params() -> None:
|
||||||
"""ObjectSchema.validate_value 与 validate_json_schema_value、Tool.validate_params 一致。"""
|
"""ObjectSchema.validate_value 与 validate_json_schema_value、Tool.validate_params 一致。"""
|
||||||
root = tool_parameters_schema(
|
root = tool_parameters_schema(
|
||||||
@ -101,6 +122,31 @@ def test_schema_classes_equivalent_to_sample_tool_parameters() -> None:
|
|||||||
assert built == SampleTool().parameters
|
assert built == SampleTool().parameters
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_parameters_returns_fresh_copy_per_access() -> None:
|
||||||
|
tool = DecoratedSampleTool()
|
||||||
|
|
||||||
|
first = tool.parameters
|
||||||
|
second = tool.parameters
|
||||||
|
|
||||||
|
assert first == second
|
||||||
|
assert first is not second
|
||||||
|
assert first["properties"] is not second["properties"]
|
||||||
|
|
||||||
|
first["properties"]["query"]["minLength"] = 99
|
||||||
|
assert tool.parameters["properties"]["query"]["minLength"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_registry_executes_decorated_tool_end_to_end() -> None:
|
||||||
|
reg = ToolRegistry()
|
||||||
|
reg.register(DecoratedSampleTool())
|
||||||
|
|
||||||
|
ok = await reg.execute("decorated_sample", {"query": "hello", "count": "3"})
|
||||||
|
assert ok == "ok:3"
|
||||||
|
|
||||||
|
err = await reg.execute("decorated_sample", {"query": "h", "count": 3})
|
||||||
|
assert "Invalid parameters" in err
|
||||||
|
|
||||||
|
|
||||||
def test_validate_params_missing_required() -> None:
|
def test_validate_params_missing_required() -> None:
|
||||||
tool = SampleTool()
|
tool = SampleTool()
|
||||||
errors = tool.validate_params({"query": "hi"})
|
errors = tool.validate_params({"query": "hi"})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user