mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-06 03:03:36 +00:00
Refactor Tool methods and type handling; introduce JSON Schema support for tool parameters (schema module, validation tests). Made-with: Cursor
279 lines
10 KiB
Python
279 lines
10 KiB
Python
"""Base class for agent tools."""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Callable
|
|
from typing import Any, TypeVar
|
|
|
|
_ToolT = TypeVar("_ToolT", bound="Tool")
|
|
|
|
# Matches :meth:`Tool._cast_value` / :meth:`Schema.validate_json_schema_value` behavior
|
|
_JSON_TYPE_MAP: dict[str, type | tuple[type, ...]] = {
|
|
"string": str,
|
|
"integer": int,
|
|
"number": (int, float),
|
|
"boolean": bool,
|
|
"array": list,
|
|
"object": dict,
|
|
}
|
|
|
|
|
|
class Schema(ABC):
|
|
"""Abstract base for JSON Schema fragments describing tool parameters.
|
|
|
|
Concrete types live in :mod:`nanobot.agent.tools.schema`; all implement
|
|
:meth:`to_json_schema` and :meth:`validate_value`. Class methods
|
|
:meth:`validate_json_schema_value` and :meth:`fragment` are the shared validation and normalization entry points.
|
|
"""
|
|
|
|
@staticmethod
|
|
def resolve_json_schema_type(t: Any) -> str | None:
|
|
"""Resolve the non-null type name from JSON Schema ``type`` (e.g. ``['string','null']`` -> ``'string'``)."""
|
|
if isinstance(t, list):
|
|
return next((x for x in t if x != "null"), None)
|
|
return t # type: ignore[return-value]
|
|
|
|
@staticmethod
|
|
def subpath(path: str, key: str) -> str:
|
|
return f"{path}.{key}" if path else key
|
|
|
|
@staticmethod
|
|
def validate_json_schema_value(val: Any, schema: dict[str, Any], path: str = "") -> list[str]:
|
|
"""Validate ``val`` against a JSON Schema fragment; returns error messages (empty means valid).
|
|
|
|
Used by :class:`Tool` and each concrete Schema's :meth:`validate_value`.
|
|
"""
|
|
raw_type = schema.get("type")
|
|
nullable = (isinstance(raw_type, list) and "null" in raw_type) or schema.get("nullable", False)
|
|
t = Schema.resolve_json_schema_type(raw_type)
|
|
label = path or "parameter"
|
|
|
|
if nullable and val is None:
|
|
return []
|
|
if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)):
|
|
return [f"{label} should be integer"]
|
|
if t == "number" and (
|
|
not isinstance(val, _JSON_TYPE_MAP["number"]) or isinstance(val, bool)
|
|
):
|
|
return [f"{label} should be number"]
|
|
if t in _JSON_TYPE_MAP and t not in ("integer", "number") and not isinstance(val, _JSON_TYPE_MAP[t]):
|
|
return [f"{label} should be {t}"]
|
|
|
|
errors: list[str] = []
|
|
if "enum" in schema and val not in schema["enum"]:
|
|
errors.append(f"{label} must be one of {schema['enum']}")
|
|
if t in ("integer", "number"):
|
|
if "minimum" in schema and val < schema["minimum"]:
|
|
errors.append(f"{label} must be >= {schema['minimum']}")
|
|
if "maximum" in schema and val > schema["maximum"]:
|
|
errors.append(f"{label} must be <= {schema['maximum']}")
|
|
if t == "string":
|
|
if "minLength" in schema and len(val) < schema["minLength"]:
|
|
errors.append(f"{label} must be at least {schema['minLength']} chars")
|
|
if "maxLength" in schema and len(val) > schema["maxLength"]:
|
|
errors.append(f"{label} must be at most {schema['maxLength']} chars")
|
|
if t == "object":
|
|
props = schema.get("properties", {})
|
|
for k in schema.get("required", []):
|
|
if k not in val:
|
|
errors.append(f"missing required {Schema.subpath(path, k)}")
|
|
for k, v in val.items():
|
|
if k in props:
|
|
errors.extend(Schema.validate_json_schema_value(v, props[k], Schema.subpath(path, k)))
|
|
if t == "array":
|
|
if "minItems" in schema and len(val) < schema["minItems"]:
|
|
errors.append(f"{label} must have at least {schema['minItems']} items")
|
|
if "maxItems" in schema and len(val) > schema["maxItems"]:
|
|
errors.append(f"{label} must be at most {schema['maxItems']} items")
|
|
if "items" in schema:
|
|
prefix = f"{path}[{{}}]" if path else "[{}]"
|
|
for i, item in enumerate(val):
|
|
errors.extend(
|
|
Schema.validate_json_schema_value(item, schema["items"], prefix.format(i))
|
|
)
|
|
return errors
|
|
|
|
@staticmethod
|
|
def fragment(value: Any) -> dict[str, Any]:
|
|
"""Normalize a Schema instance or an existing JSON Schema dict to a fragment dict."""
|
|
# Try to_json_schema first: Schema instances must be distinguished from dicts that are already JSON Schema
|
|
to_js = getattr(value, "to_json_schema", None)
|
|
if callable(to_js):
|
|
return to_js()
|
|
if isinstance(value, dict):
|
|
return value
|
|
raise TypeError(f"Expected schema object or dict, got {type(value).__name__}")
|
|
|
|
@abstractmethod
|
|
def to_json_schema(self) -> dict[str, Any]:
|
|
"""Return a fragment dict compatible with :meth:`validate_json_schema_value`."""
|
|
...
|
|
|
|
def validate_value(self, value: Any, path: str = "") -> list[str]:
|
|
"""Validate a single value; returns error messages (empty means pass). Subclasses may override for extra rules."""
|
|
return Schema.validate_json_schema_value(value, self.to_json_schema(), path)
|
|
|
|
|
|
class Tool(ABC):
|
|
"""Agent capability: read files, run commands, etc."""
|
|
|
|
_TYPE_MAP = {
|
|
"string": str,
|
|
"integer": int,
|
|
"number": (int, float),
|
|
"boolean": bool,
|
|
"array": list,
|
|
"object": dict,
|
|
}
|
|
_BOOL_TRUE = frozenset(("true", "1", "yes"))
|
|
_BOOL_FALSE = frozenset(("false", "0", "no"))
|
|
|
|
@staticmethod
|
|
def _resolve_type(t: Any) -> str | None:
|
|
"""Pick first non-null type from JSON Schema unions like ``['string','null']``."""
|
|
return Schema.resolve_json_schema_type(t)
|
|
|
|
@property
|
|
@abstractmethod
|
|
def name(self) -> str:
|
|
"""Tool name used in function calls."""
|
|
...
|
|
|
|
@property
|
|
@abstractmethod
|
|
def description(self) -> str:
|
|
"""Description of what the tool does."""
|
|
...
|
|
|
|
@property
|
|
@abstractmethod
|
|
def parameters(self) -> dict[str, Any]:
|
|
"""JSON Schema for tool parameters."""
|
|
...
|
|
|
|
@property
|
|
def read_only(self) -> bool:
|
|
"""Whether this tool is side-effect free and safe to parallelize."""
|
|
return False
|
|
|
|
@property
|
|
def concurrency_safe(self) -> bool:
|
|
"""Whether this tool can run alongside other concurrency-safe tools."""
|
|
return self.read_only and not self.exclusive
|
|
|
|
@property
|
|
def exclusive(self) -> bool:
|
|
"""Whether this tool should run alone even if concurrency is enabled."""
|
|
return False
|
|
|
|
@abstractmethod
|
|
async def execute(self, **kwargs: Any) -> Any:
|
|
"""Run the tool; returns a string or list of content blocks."""
|
|
...
|
|
|
|
def _cast_object(self, obj: Any, schema: dict[str, Any]) -> dict[str, Any]:
|
|
if not isinstance(obj, dict):
|
|
return obj
|
|
props = schema.get("properties", {})
|
|
return {k: self._cast_value(v, props[k]) if k in props else v for k, v in obj.items()}
|
|
|
|
def cast_params(self, params: dict[str, Any]) -> dict[str, Any]:
|
|
"""Apply safe schema-driven casts before validation."""
|
|
schema = self.parameters or {}
|
|
if schema.get("type", "object") != "object":
|
|
return params
|
|
return self._cast_object(params, schema)
|
|
|
|
def _cast_value(self, val: Any, schema: dict[str, Any]) -> Any:
|
|
t = self._resolve_type(schema.get("type"))
|
|
|
|
if t == "boolean" and isinstance(val, bool):
|
|
return val
|
|
if t == "integer" and isinstance(val, int) and not isinstance(val, bool):
|
|
return val
|
|
if t in self._TYPE_MAP and t not in ("boolean", "integer", "array", "object"):
|
|
expected = self._TYPE_MAP[t]
|
|
if isinstance(val, expected):
|
|
return val
|
|
|
|
if isinstance(val, str) and t in ("integer", "number"):
|
|
try:
|
|
return int(val) if t == "integer" else float(val)
|
|
except ValueError:
|
|
return val
|
|
|
|
if t == "string":
|
|
return val if val is None else str(val)
|
|
|
|
if t == "boolean" and isinstance(val, str):
|
|
low = val.lower()
|
|
if low in self._BOOL_TRUE:
|
|
return True
|
|
if low in self._BOOL_FALSE:
|
|
return False
|
|
return val
|
|
|
|
if t == "array" and isinstance(val, list):
|
|
items = schema.get("items")
|
|
return [self._cast_value(x, items) for x in val] if items else val
|
|
|
|
if t == "object" and isinstance(val, dict):
|
|
return self._cast_object(val, schema)
|
|
|
|
return val
|
|
|
|
def validate_params(self, params: dict[str, Any]) -> list[str]:
|
|
"""Validate against JSON schema; empty list means valid."""
|
|
if not isinstance(params, dict):
|
|
return [f"parameters must be an object, got {type(params).__name__}"]
|
|
schema = self.parameters or {}
|
|
if schema.get("type", "object") != "object":
|
|
raise ValueError(f"Schema must be object type, got {schema.get('type')!r}")
|
|
return Schema.validate_json_schema_value(params, {**schema, "type": "object"}, "")
|
|
|
|
def to_schema(self) -> dict[str, Any]:
|
|
"""OpenAI function schema."""
|
|
return {
|
|
"type": "function",
|
|
"function": {
|
|
"name": self.name,
|
|
"description": self.description,
|
|
"parameters": self.parameters,
|
|
},
|
|
}
|
|
|
|
|
|
def tool_parameters(schema: dict[str, Any]) -> Callable[[type[_ToolT]], type[_ToolT]]:
    """Class decorator: attach JSON Schema and inject a concrete ``parameters`` property.

    Use on ``Tool`` subclasses instead of writing ``@property def parameters``. The
    schema is stored on the class (shallow-copied) as ``_tool_parameters_schema``,
    and the injected property returns that same stored dict on every access.

    Example::

        @tool_parameters({
            "type": "object",
            "properties": {"path": {"type": "string"}},
            "required": ["path"],
        })
        class ReadFileTool(Tool):
            ...
    """

    def decorator(cls: type[_ToolT]) -> type[_ToolT]:
        # Shallow copy: later top-level edits to the caller's dict do not leak in.
        stored = dict(schema)

        def parameters(self: Any) -> dict[str, Any]:
            return stored

        cls._tool_parameters_schema = stored
        cls.parameters = property(parameters)  # type: ignore[assignment]

        # The abstract set was computed when the class was created, before the
        # property was injected, so "parameters" has to be removed by hand for
        # the class to become instantiable.
        pending = getattr(cls, "__abstractmethods__", None)
        if pending is not None and "parameters" in pending:
            cls.__abstractmethods__ = frozenset(pending - {"parameters"})  # type: ignore[misc]

        return cls

    return decorator
|