mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-08 12:13:36 +00:00
refactor: extract tool hint formatting to utils/tool_hints.py
- Move _tool_hint implementation from loop.py to nanobot/utils/tool_hints.py - Keep thin delegation in AgentLoop._tool_hint for backward compat - Update test imports to test format_tool_hints directly Made-with: Cursor
This commit is contained in:
parent
3e3a7654f8
commit
82dec12f66
@ -323,117 +323,12 @@ class AgentLoop:
|
||||
from nanobot.utils.helpers import strip_think
|
||||
return strip_think(text) or None
|
||||
|
||||
# Registry: tool_name -> (key_args, template, is_path, is_command)
|
||||
_TOOL_HINT_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = {
|
||||
"read_file": (["path", "file_path"], "read {}", True, False),
|
||||
"write_file": (["path", "file_path"], "write {}", True, False),
|
||||
"edit": (["file_path", "path"], "edit {}", True, False),
|
||||
"glob": (["pattern"], 'glob "{}"', False, False),
|
||||
"grep": (["pattern"], 'grep "{}"', False, False),
|
||||
"exec": (["command"], "$ {}", False, True),
|
||||
"web_search": (["query"], 'search "{}"', False, False),
|
||||
"web_fetch": (["url"], "fetch {}", True, False),
|
||||
"list_dir": (["path"], "ls {}", True, False),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _tool_hint(tool_calls: list) -> str:
|
||||
"""Format tool calls as concise hints with smart abbreviation."""
|
||||
if not tool_calls:
|
||||
return ""
|
||||
from nanobot.utils.tool_hints import format_tool_hints
|
||||
|
||||
from nanobot.utils.path import abbreviate_path
|
||||
|
||||
def _get_args(tc) -> dict:
|
||||
"""Extract args dict from tc.arguments, handling list/dict/None/empty."""
|
||||
if tc.arguments is None:
|
||||
return {}
|
||||
if isinstance(tc.arguments, list):
|
||||
return tc.arguments[0] if tc.arguments else {}
|
||||
if isinstance(tc.arguments, dict):
|
||||
return tc.arguments
|
||||
return {}
|
||||
|
||||
def _group_consecutive(calls):
|
||||
"""Group consecutive calls to the same tool: [(name, count, first), ...]."""
|
||||
groups = []
|
||||
for tc in calls:
|
||||
if groups and groups[-1][0] == tc.name:
|
||||
groups[-1] = (groups[-1][0], groups[-1][1] + 1, groups[-1][2])
|
||||
else:
|
||||
groups.append((tc.name, 1, tc))
|
||||
return groups
|
||||
|
||||
def _extract_arg(tc, key_args):
|
||||
"""Extract the first available value from preferred key names."""
|
||||
args = _get_args(tc)
|
||||
if not isinstance(args, dict):
|
||||
return None
|
||||
for key in key_args:
|
||||
val = args.get(key)
|
||||
if isinstance(val, str) and val:
|
||||
return val
|
||||
# Fallback: first string value
|
||||
for val in args.values():
|
||||
if isinstance(val, str) and val:
|
||||
return val
|
||||
return None
|
||||
|
||||
def _fmt_known(tc, fmt):
|
||||
"""Format a registered tool using its template."""
|
||||
val = _extract_arg(tc, fmt[0])
|
||||
if val is None:
|
||||
return tc.name
|
||||
if fmt[2]: # is_path
|
||||
val = abbreviate_path(val)
|
||||
elif fmt[3]: # is_command
|
||||
val = val[:40] + "\u2026" if len(val) > 40 else val
|
||||
return fmt[1].format(val)
|
||||
|
||||
def _fmt_mcp(tc):
|
||||
"""Format MCP tool as server::tool."""
|
||||
name = tc.name
|
||||
# mcp_<server>__<tool> or mcp_<server>_<tool>
|
||||
if "__" in name:
|
||||
parts = name.split("__", 1)
|
||||
server = parts[0].removeprefix("mcp_")
|
||||
tool = parts[1]
|
||||
else:
|
||||
rest = name.removeprefix("mcp_")
|
||||
parts = rest.split("_", 1)
|
||||
server = parts[0] if parts else rest
|
||||
tool = parts[1] if len(parts) > 1 else ""
|
||||
if not tool:
|
||||
return name
|
||||
args = _get_args(tc)
|
||||
val = next((v for v in args.values() if isinstance(v, str) and v), None)
|
||||
if val is None:
|
||||
return f"{server}::{tool}"
|
||||
return f'{server}::{tool}("{abbreviate_path(val, 40)}")'
|
||||
|
||||
def _fmt_fallback(tc):
|
||||
"""Original formatting logic for unregistered tools."""
|
||||
args = _get_args(tc)
|
||||
val = next(iter(args.values()), None) if isinstance(args, dict) else None
|
||||
if not isinstance(val, str):
|
||||
return tc.name
|
||||
return f'{tc.name}("{abbreviate_path(val, 40)}")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||
|
||||
hints = []
|
||||
for name, count, example_tc in _group_consecutive(tool_calls):
|
||||
fmt = AgentLoop._TOOL_HINT_FORMATS.get(name)
|
||||
if fmt:
|
||||
hint = _fmt_known(example_tc, fmt)
|
||||
elif name.startswith("mcp_"):
|
||||
hint = _fmt_mcp(example_tc)
|
||||
else:
|
||||
hint = _fmt_fallback(example_tc)
|
||||
|
||||
if count > 1:
|
||||
hint = f"{hint} \u00d7 {count}"
|
||||
hints.append(hint)
|
||||
|
||||
return ", ".join(hints)
|
||||
return format_tool_hints(tool_calls)
|
||||
|
||||
async def _run_agent_loop(
|
||||
self,
|
||||
|
||||
119
nanobot/utils/tool_hints.py
Normal file
119
nanobot/utils/tool_hints.py
Normal file
@ -0,0 +1,119 @@
|
||||
"""Tool hint formatting for concise, human-readable tool call display."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from nanobot.utils.path import abbreviate_path
|
||||
|
||||
# Registry: tool_name -> (key_args, template, is_path, is_command)
|
||||
_TOOL_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = {
|
||||
"read_file": (["path", "file_path"], "read {}", True, False),
|
||||
"write_file": (["path", "file_path"], "write {}", True, False),
|
||||
"edit": (["file_path", "path"], "edit {}", True, False),
|
||||
"glob": (["pattern"], 'glob "{}"', False, False),
|
||||
"grep": (["pattern"], 'grep "{}"', False, False),
|
||||
"exec": (["command"], "$ {}", False, True),
|
||||
"web_search": (["query"], 'search "{}"', False, False),
|
||||
"web_fetch": (["url"], "fetch {}", True, False),
|
||||
"list_dir": (["path"], "ls {}", True, False),
|
||||
}
|
||||
|
||||
|
||||
def format_tool_hints(tool_calls: list) -> str:
|
||||
"""Format tool calls as concise hints with smart abbreviation."""
|
||||
if not tool_calls:
|
||||
return ""
|
||||
|
||||
hints = []
|
||||
for name, count, example_tc in _group_consecutive(tool_calls):
|
||||
fmt = _TOOL_FORMATS.get(name)
|
||||
if fmt:
|
||||
hint = _fmt_known(example_tc, fmt)
|
||||
elif name.startswith("mcp_"):
|
||||
hint = _fmt_mcp(example_tc)
|
||||
else:
|
||||
hint = _fmt_fallback(example_tc)
|
||||
|
||||
if count > 1:
|
||||
hint = f"{hint} \u00d7 {count}"
|
||||
hints.append(hint)
|
||||
|
||||
return ", ".join(hints)
|
||||
|
||||
|
||||
def _get_args(tc) -> dict:
|
||||
"""Extract args dict from tc.arguments, handling list/dict/None/empty."""
|
||||
if tc.arguments is None:
|
||||
return {}
|
||||
if isinstance(tc.arguments, list):
|
||||
return tc.arguments[0] if tc.arguments else {}
|
||||
if isinstance(tc.arguments, dict):
|
||||
return tc.arguments
|
||||
return {}
|
||||
|
||||
|
||||
def _group_consecutive(calls: list) -> list[tuple[str, int, object]]:
|
||||
"""Group consecutive calls to the same tool: [(name, count, first), ...]."""
|
||||
groups: list[tuple[str, int, object]] = []
|
||||
for tc in calls:
|
||||
if groups and groups[-1][0] == tc.name:
|
||||
groups[-1] = (groups[-1][0], groups[-1][1] + 1, groups[-1][2])
|
||||
else:
|
||||
groups.append((tc.name, 1, tc))
|
||||
return groups
|
||||
|
||||
|
||||
def _extract_arg(tc, key_args: list[str]) -> str | None:
|
||||
"""Extract the first available value from preferred key names."""
|
||||
args = _get_args(tc)
|
||||
if not isinstance(args, dict):
|
||||
return None
|
||||
for key in key_args:
|
||||
val = args.get(key)
|
||||
if isinstance(val, str) and val:
|
||||
return val
|
||||
for val in args.values():
|
||||
if isinstance(val, str) and val:
|
||||
return val
|
||||
return None
|
||||
|
||||
|
||||
def _fmt_known(tc, fmt: tuple) -> str:
|
||||
"""Format a registered tool using its template."""
|
||||
val = _extract_arg(tc, fmt[0])
|
||||
if val is None:
|
||||
return tc.name
|
||||
if fmt[2]: # is_path
|
||||
val = abbreviate_path(val)
|
||||
elif fmt[3]: # is_command
|
||||
val = val[:40] + "\u2026" if len(val) > 40 else val
|
||||
return fmt[1].format(val)
|
||||
|
||||
|
||||
def _fmt_mcp(tc) -> str:
|
||||
"""Format MCP tool as server::tool."""
|
||||
name = tc.name
|
||||
if "__" in name:
|
||||
parts = name.split("__", 1)
|
||||
server = parts[0].removeprefix("mcp_")
|
||||
tool = parts[1]
|
||||
else:
|
||||
rest = name.removeprefix("mcp_")
|
||||
parts = rest.split("_", 1)
|
||||
server = parts[0] if parts else rest
|
||||
tool = parts[1] if len(parts) > 1 else ""
|
||||
if not tool:
|
||||
return name
|
||||
args = _get_args(tc)
|
||||
val = next((v for v in args.values() if isinstance(v, str) and v), None)
|
||||
if val is None:
|
||||
return f"{server}::{tool}"
|
||||
return f'{server}::{tool}("{abbreviate_path(val, 40)}")'
|
||||
|
||||
|
||||
def _fmt_fallback(tc) -> str:
|
||||
"""Original formatting logic for unregistered tools."""
|
||||
args = _get_args(tc)
|
||||
val = next(iter(args.values()), None) if isinstance(args, dict) else None
|
||||
if not isinstance(val, str):
|
||||
return tc.name
|
||||
return f'{tc.name}("{abbreviate_path(val, 40)}")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||
@ -1,6 +1,6 @@
|
||||
"""Tests for AgentLoop._tool_hint() formatting."""
|
||||
"""Tests for tool hint formatting (nanobot.utils.tool_hints)."""
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.utils.tool_hints import format_tool_hints
|
||||
from nanobot.providers.base import ToolCallRequest
|
||||
|
||||
|
||||
@ -8,51 +8,56 @@ def _tc(name: str, args) -> ToolCallRequest:
|
||||
return ToolCallRequest(id="c1", name=name, arguments=args)
|
||||
|
||||
|
||||
def _hint(calls):
|
||||
"""Shortcut for format_tool_hints."""
|
||||
return format_tool_hints(calls)
|
||||
|
||||
|
||||
class TestToolHintKnownTools:
|
||||
"""Test registered tool types produce correct formatted output."""
|
||||
|
||||
def test_read_file_short_path(self):
|
||||
result = AgentLoop._tool_hint([_tc("read_file", {"path": "foo.txt"})])
|
||||
result = _hint([_tc("read_file", {"path": "foo.txt"})])
|
||||
assert result == 'read foo.txt'
|
||||
|
||||
def test_read_file_long_path(self):
|
||||
result = AgentLoop._tool_hint([_tc("read_file", {"path": "/home/user/.local/share/uv/tools/nanobot/agent/loop.py"})])
|
||||
result = _hint([_tc("read_file", {"path": "/home/user/.local/share/uv/tools/nanobot/agent/loop.py"})])
|
||||
assert "loop.py" in result
|
||||
assert "read " in result
|
||||
|
||||
def test_write_file_shows_path_not_content(self):
|
||||
result = AgentLoop._tool_hint([_tc("write_file", {"path": "docs/api.md", "content": "# API Reference\n\nLong content..."})])
|
||||
result = _hint([_tc("write_file", {"path": "docs/api.md", "content": "# API Reference\n\nLong content..."})])
|
||||
assert result == "write docs/api.md"
|
||||
|
||||
def test_edit_shows_path(self):
|
||||
result = AgentLoop._tool_hint([_tc("edit", {"file_path": "src/main.py", "old_string": "x", "new_string": "y"})])
|
||||
result = _hint([_tc("edit", {"file_path": "src/main.py", "old_string": "x", "new_string": "y"})])
|
||||
assert "main.py" in result
|
||||
assert "edit " in result
|
||||
|
||||
def test_glob_shows_pattern(self):
|
||||
result = AgentLoop._tool_hint([_tc("glob", {"pattern": "**/*.py", "path": "src"})])
|
||||
result = _hint([_tc("glob", {"pattern": "**/*.py", "path": "src"})])
|
||||
assert result == 'glob "**/*.py"'
|
||||
|
||||
def test_grep_shows_pattern(self):
|
||||
result = AgentLoop._tool_hint([_tc("grep", {"pattern": "TODO|FIXME", "path": "src"})])
|
||||
result = _hint([_tc("grep", {"pattern": "TODO|FIXME", "path": "src"})])
|
||||
assert result == 'grep "TODO|FIXME"'
|
||||
|
||||
def test_exec_shows_command(self):
|
||||
result = AgentLoop._tool_hint([_tc("exec", {"command": "npm install typescript"})])
|
||||
result = _hint([_tc("exec", {"command": "npm install typescript"})])
|
||||
assert result == "$ npm install typescript"
|
||||
|
||||
def test_exec_truncates_long_command(self):
|
||||
cmd = "cd /very/long/path && cat file && echo done && sleep 1 && ls -la"
|
||||
result = AgentLoop._tool_hint([_tc("exec", {"command": cmd})])
|
||||
result = _hint([_tc("exec", {"command": cmd})])
|
||||
assert result.startswith("$ ")
|
||||
assert len(result) <= 50 # reasonable limit
|
||||
|
||||
def test_web_search(self):
|
||||
result = AgentLoop._tool_hint([_tc("web_search", {"query": "Claude 4 vs GPT-4"})])
|
||||
result = _hint([_tc("web_search", {"query": "Claude 4 vs GPT-4"})])
|
||||
assert result == 'search "Claude 4 vs GPT-4"'
|
||||
|
||||
def test_web_fetch(self):
|
||||
result = AgentLoop._tool_hint([_tc("web_fetch", {"url": "https://example.com/page"})])
|
||||
result = _hint([_tc("web_fetch", {"url": "https://example.com/page"})])
|
||||
assert result == "fetch https://example.com/page"
|
||||
|
||||
|
||||
@ -60,12 +65,12 @@ class TestToolHintMCP:
|
||||
"""Test MCP tools are abbreviated to server::tool format."""
|
||||
|
||||
def test_mcp_standard_format(self):
|
||||
result = AgentLoop._tool_hint([_tc("mcp_4_5v_mcp__analyze_image", {"imageSource": "https://img.jpg", "prompt": "describe"})])
|
||||
result = _hint([_tc("mcp_4_5v_mcp__analyze_image", {"imageSource": "https://img.jpg", "prompt": "describe"})])
|
||||
assert "4_5v" in result
|
||||
assert "analyze_image" in result
|
||||
|
||||
def test_mcp_simple_name(self):
|
||||
result = AgentLoop._tool_hint([_tc("mcp_github__create_issue", {"title": "Bug fix"})])
|
||||
result = _hint([_tc("mcp_github__create_issue", {"title": "Bug fix"})])
|
||||
assert "github" in result
|
||||
assert "create_issue" in result
|
||||
|
||||
@ -74,21 +79,21 @@ class TestToolHintFallback:
|
||||
"""Test unknown tools fall back to original behavior."""
|
||||
|
||||
def test_unknown_tool_with_string_arg(self):
|
||||
result = AgentLoop._tool_hint([_tc("custom_tool", {"data": "hello world"})])
|
||||
result = _hint([_tc("custom_tool", {"data": "hello world"})])
|
||||
assert result == 'custom_tool("hello world")'
|
||||
|
||||
def test_unknown_tool_with_long_arg_truncates(self):
|
||||
long_val = "a" * 60
|
||||
result = AgentLoop._tool_hint([_tc("custom_tool", {"data": long_val})])
|
||||
result = _hint([_tc("custom_tool", {"data": long_val})])
|
||||
assert len(result) < 80
|
||||
assert "\u2026" in result
|
||||
|
||||
def test_unknown_tool_no_string_arg(self):
|
||||
result = AgentLoop._tool_hint([_tc("custom_tool", {"count": 42})])
|
||||
result = _hint([_tc("custom_tool", {"count": 42})])
|
||||
assert result == "custom_tool"
|
||||
|
||||
def test_empty_tool_calls(self):
|
||||
result = AgentLoop._tool_hint([])
|
||||
result = _hint([])
|
||||
assert result == ""
|
||||
|
||||
|
||||
@ -97,7 +102,7 @@ class TestToolHintFolding:
|
||||
|
||||
def test_single_call_no_fold(self):
|
||||
calls = [_tc("grep", {"pattern": "*.py"})]
|
||||
result = AgentLoop._tool_hint(calls)
|
||||
result = _hint(calls)
|
||||
assert "\u00d7" not in result
|
||||
|
||||
def test_two_consecutive_same_folded(self):
|
||||
@ -105,7 +110,7 @@ class TestToolHintFolding:
|
||||
_tc("grep", {"pattern": "*.py"}),
|
||||
_tc("grep", {"pattern": "*.ts"}),
|
||||
]
|
||||
result = AgentLoop._tool_hint(calls)
|
||||
result = _hint(calls)
|
||||
assert "\u00d7 2" in result
|
||||
|
||||
def test_three_consecutive_same_folded(self):
|
||||
@ -114,7 +119,7 @@ class TestToolHintFolding:
|
||||
_tc("read_file", {"path": "b.py"}),
|
||||
_tc("read_file", {"path": "c.py"}),
|
||||
]
|
||||
result = AgentLoop._tool_hint(calls)
|
||||
result = _hint(calls)
|
||||
assert "\u00d7 3" in result
|
||||
|
||||
def test_different_tools_not_folded(self):
|
||||
@ -122,7 +127,7 @@ class TestToolHintFolding:
|
||||
_tc("grep", {"pattern": "TODO"}),
|
||||
_tc("read_file", {"path": "a.py"}),
|
||||
]
|
||||
result = AgentLoop._tool_hint(calls)
|
||||
result = _hint(calls)
|
||||
assert "\u00d7" not in result
|
||||
|
||||
def test_interleaved_same_tools_not_folded(self):
|
||||
@ -131,7 +136,7 @@ class TestToolHintFolding:
|
||||
_tc("read_file", {"path": "f.py"}),
|
||||
_tc("grep", {"pattern": "b"}),
|
||||
]
|
||||
result = AgentLoop._tool_hint(calls)
|
||||
result = _hint(calls)
|
||||
assert "\u00d7" not in result
|
||||
|
||||
|
||||
@ -143,7 +148,7 @@ class TestToolHintMultipleCalls:
|
||||
_tc("grep", {"pattern": "TODO"}),
|
||||
_tc("read_file", {"path": "main.py"}),
|
||||
]
|
||||
result = AgentLoop._tool_hint(calls)
|
||||
result = _hint(calls)
|
||||
assert 'grep "TODO"' in result
|
||||
assert "read main.py" in result
|
||||
assert ", " in result
|
||||
@ -154,27 +159,27 @@ class TestToolHintEdgeCases:
|
||||
|
||||
def test_known_tool_empty_list_args(self):
|
||||
"""C1/G1: Empty list arguments should not crash."""
|
||||
result = AgentLoop._tool_hint([_tc("read_file", [])])
|
||||
result = _hint([_tc("read_file", [])])
|
||||
assert result == "read_file"
|
||||
|
||||
def test_known_tool_none_args(self):
|
||||
"""G2: None arguments should not crash."""
|
||||
result = AgentLoop._tool_hint([_tc("read_file", None)])
|
||||
result = _hint([_tc("read_file", None)])
|
||||
assert result == "read_file"
|
||||
|
||||
def test_fallback_empty_list_args(self):
|
||||
"""C1: Empty list args in fallback should not crash."""
|
||||
result = AgentLoop._tool_hint([_tc("custom_tool", [])])
|
||||
result = _hint([_tc("custom_tool", [])])
|
||||
assert result == "custom_tool"
|
||||
|
||||
def test_fallback_none_args(self):
|
||||
"""G2: None args in fallback should not crash."""
|
||||
result = AgentLoop._tool_hint([_tc("custom_tool", None)])
|
||||
result = _hint([_tc("custom_tool", None)])
|
||||
assert result == "custom_tool"
|
||||
|
||||
def test_list_dir_registered(self):
|
||||
"""S2: list_dir should use 'ls' format."""
|
||||
result = AgentLoop._tool_hint([_tc("list_dir", {"path": "/tmp"})])
|
||||
result = _hint([_tc("list_dir", {"path": "/tmp"})])
|
||||
assert result == "ls /tmp"
|
||||
|
||||
|
||||
@ -190,7 +195,7 @@ class TestToolHintMixedFolding:
|
||||
_tc("grep", {"pattern": "y"}),
|
||||
_tc("read_file", {"path": "c.py"}),
|
||||
]
|
||||
result = AgentLoop._tool_hint(calls)
|
||||
result = _hint(calls)
|
||||
assert "\u00d7 2" in result
|
||||
# Should have 3 groups: read×2, grep×2, read
|
||||
parts = result.split(", ")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user