# Mirror of https://github.com/HKUDS/nanobot.git
# Synced 2026-04-11 21:53:37 +00:00
# 111 lines, 3.5 KiB, Python
"""Tool registry for dynamic tool management."""
|
|
|
|
from typing import Any
|
|
|
|
from nanobot.agent.tools.base import Tool
|
|
|
|
|
|
class ToolRegistry:
    """Registry for agent tools.

    Holds a name -> tool mapping and allows dynamic registration, lookup,
    schema export, and execution of the registered tools.
    """

    def __init__(self):
        # Tool name -> tool instance, in registration order.
        self._tools: dict[str, Tool] = {}

    def register(self, tool: Tool) -> None:
        """Register a tool under ``tool.name``, replacing any previous entry."""
        self._tools[tool.name] = tool

    def unregister(self, name: str) -> None:
        """Unregister a tool by name; unknown names are ignored."""
        self._tools.pop(name, None)

    def get(self, name: str) -> Tool | None:
        """Return the tool registered as *name*, or ``None`` if absent."""
        return self._tools.get(name)

    def has(self, name: str) -> bool:
        """Return whether a tool is registered under *name*."""
        return name in self._tools

    @staticmethod
    def _schema_name(schema: dict[str, Any]) -> str:
        """Extract a tool name from either an OpenAI-style or a flat schema.

        The nested ``schema["function"]["name"]`` is preferred; the top-level
        ``schema["name"]`` is the fallback. Returns ``""`` when neither is a
        string.
        """
        fn = schema.get("function")
        if isinstance(fn, dict):
            nested = fn.get("name")
            if isinstance(nested, str):
                return nested
        flat = schema.get("name")
        return flat if isinstance(flat, str) else ""

    def get_definitions(self) -> list[dict[str, Any]]:
        """Get tool definitions with stable ordering for cache-friendly prompts.

        Built-in tools are sorted by name first as a stable prefix, then the
        tools whose schema name starts with ``"mcp_"`` are sorted and appended.
        """
        builtins: list[dict[str, Any]] = []
        mcp_tools: list[dict[str, Any]] = []
        for tool in self._tools.values():
            schema = tool.to_schema()
            is_mcp = self._schema_name(schema).startswith("mcp_")
            (mcp_tools if is_mcp else builtins).append(schema)

        # list.sort is stable, so equal names keep their registration order.
        builtins.sort(key=self._schema_name)
        mcp_tools.sort(key=self._schema_name)
        return builtins + mcp_tools

    def prepare_call(
        self,
        name: str,
        params: dict[str, Any],
    ) -> tuple[Tool | None, dict[str, Any], str | None]:
        """Resolve, cast, and validate one tool call.

        Returns:
            A ``(tool, params, error)`` triple:

            * unknown tool: ``(None, params, message)`` with the original
              *params* untouched;
            * validation failure: ``(tool, cast_params, message)``;
            * success: ``(tool, cast_params, None)``.
        """
        tool = self._tools.get(name)
        # Identity check, not truthiness: a registered tool whose class defines
        # __len__/__bool__ may be falsy and must still be resolved.
        if tool is None:
            return None, params, (
                f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
            )

        cast_params = tool.cast_params(params)
        errors = tool.validate_params(cast_params)
        if errors:
            return tool, cast_params, (
                f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors)
            )
        return tool, cast_params, None

    async def execute(self, name: str, params: dict[str, Any]) -> Any:
        """Execute a tool by name with given parameters.

        Returns the tool's result. Lookup/validation failures, string results
        starting with ``"Error"``, and exceptions raised by the tool itself are
        all returned as an error string with a retry hint appended rather than
        raised. Exceptions from the casting/validation step are not caught
        here and propagate to the caller.
        """
        _HINT = "\n\n[Analyze the error above and try a different approach.]"
        tool, params, error = self.prepare_call(name, params)
        if error:
            return error + _HINT

        try:
            assert tool is not None  # guarded by prepare_call()
            result = await tool.execute(**params)
            if isinstance(result, str) and result.startswith("Error"):
                return result + _HINT
            return result
        except Exception as e:
            # Boundary handler: report the failure as text instead of raising.
            return f"Error executing {name}: {e}" + _HINT

    @property
    def tool_names(self) -> list[str]:
        """Get list of registered tool names, in registration order."""
        return list(self._tools)

    def __len__(self) -> int:
        """Return the number of registered tools."""
        return len(self._tools)

    def __contains__(self, name: str) -> bool:
        """Support ``name in registry``."""
        return name in self._tools
|