diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py
index e94999a57..e58156758 100644
--- a/nanobot/agent/loop.py
+++ b/nanobot/agent/loop.py
@@ -323,6 +323,19 @@ class AgentLoop:
         from nanobot.utils.helpers import strip_think
         return strip_think(text) or None
 
+    # Registry: tool_name -> (key_args, template, is_path, is_command)
+    _TOOL_HINT_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = {
+        "read_file": (["path", "file_path"], "read {}", True, False),
+        "write_file": (["path", "file_path"], "write {}", True, False),
+        "edit": (["file_path", "path"], "edit {}", True, False),
+        "glob": (["pattern"], 'glob "{}"', False, False),
+        "grep": (["pattern"], 'grep "{}"', False, False),
+        "exec": (["command"], "$ {}", False, True),
+        "web_search": (["query"], 'search "{}"', False, False),
+        "web_fetch": (["url"], "fetch {}", True, False),
+        "list_dir": (["path"], "ls {}", True, False),
+    }
+
     @staticmethod
     def _tool_hint(tool_calls: list) -> str:
         """Format tool calls as concise hints with smart abbreviation."""
@@ -331,6 +344,16 @@ class AgentLoop:
 
         from nanobot.utils.path import abbreviate_path
 
+        def _get_args(tc) -> dict:
+            """Return tc.arguments as a dict; {} for None, empty, or non-dict payloads."""
+            args = tc.arguments
+            if isinstance(args, list):
+                # List-wrapped arguments: only the first element is inspected.
+                args = args[0] if args else None
+            # A non-dict payload (None, str, list of non-dicts, ...) yields {} so
+            # callers such as _fmt_mcp can call .values() without crashing.
+            return args if isinstance(args, dict) else {}
+
         def _group_consecutive(calls):
             """Group consecutive calls to the same tool: [(name, count, first), ...]."""
             groups = []
@@ -343,7 +366,7 @@ class AgentLoop:
 
         def _extract_arg(tc, key_args):
             """Extract the first available value from preferred key names."""
-            args = (tc.arguments[0] if isinstance(tc.arguments, list) else tc.arguments) or {}
+            args = _get_args(tc)
             if not isinstance(args, dict):
                 return None
             for key in key_args:
@@ -365,10 +388,7 @@ class AgentLoop:
                 val = abbreviate_path(val)
             elif fmt[3]:  # is_command
                 val = val[:40] + "\u2026" if len(val) > 40 else val
-            template = fmt[1]
-            if '"{}"' in template:
-                return template.format(val)
-            return template.format(val)
+            return fmt[1].format(val)
 
         def _fmt_mcp(tc):
             """Format MCP tool as server::tool."""
@@ -385,7 +405,7 @@ class AgentLoop:
             tool = parts[1] if len(parts) > 1 else ""
             if not tool:
                 return name
-            args = (tc.arguments[0] if isinstance(tc.arguments, list) else tc.arguments) or {}
+            args = _get_args(tc)
             val = next((v for v in args.values() if isinstance(v, str) and v), None)
             if val is None:
                 return f"{server}::{tool}"
@@ -393,27 +413,15 @@ class AgentLoop:
 
         def _fmt_fallback(tc):
             """Original formatting logic for unregistered tools."""
-            args = (tc.arguments[0] if isinstance(tc.arguments, list) else tc.arguments) or {}
+            args = _get_args(tc)
             val = next(iter(args.values()), None) if isinstance(args, dict) else None
             if not isinstance(val, str):
                 return tc.name
             return f'{tc.name}("{abbreviate_path(val, 40)}")' if len(val) > 40 else f'{tc.name}("{val}")'
 
-        # Registry: tool_name -> (key_args, template, is_path, is_command)
-        FORMATS: dict[str, tuple[list[str], str, bool, bool]] = {
-            "read_file": (["path", "file_path"], "read {}", True, False),
-            "write_file": (["path", "file_path"], "write {}", True, False),
-            "edit": (["file_path", "path"], "edit {}", True, False),
-            "glob": (["pattern"], 'glob "{}"', False, False),
-            "grep": (["pattern"], 'grep "{}"', False, False),
-            "exec": (["command"], "$ {}", False, True),
-            "web_search": (["query"], 'search "{}"', False, False),
-            "web_fetch": (["url"], "fetch {}", True, False),
-        }
-
         hints = []
         for name, count, example_tc in _group_consecutive(tool_calls):
-            fmt = FORMATS.get(name)
+            fmt = AgentLoop._TOOL_HINT_FORMATS.get(name)
             if fmt:
                 hint = _fmt_known(example_tc, fmt)
             elif name.startswith("mcp_"):
diff --git a/nanobot/utils/__init__.py b/nanobot/utils/__init__.py
index 46f02acbd..9ad157c2e 100644
--- a/nanobot/utils/__init__.py
+++ b/nanobot/utils/__init__.py
@@ -1,5 +1,6 @@
 """Utility functions for nanobot."""
 
 from nanobot.utils.helpers import ensure_dir
+from nanobot.utils.path import abbreviate_path
 
-__all__ = ["ensure_dir"]
+__all__ = ["ensure_dir", "abbreviate_path"]
diff --git a/nanobot/utils/path.py b/nanobot/utils/path.py
index d23836e08..32591a471 100644
--- a/nanobot/utils/path.py
+++ b/nanobot/utils/path.py
@@ -89,7 +89,8 @@ def _abbreviate_url(url: str, max_len: int = 40) -> str:
 
     budget = max_len - len(domain) - len(basename) - 4  # "…/" + "/"
     if budget < 0:
-        return domain + "/\u2026" + basename[: max_len - len(domain) - 4]
+        trunc = max_len - len(domain) - 5  # "…/" + "/"
+        return domain + "/\u2026/" + (basename[:trunc] if trunc > 0 else "")
 
     # Build abbreviated path
     kept: list[str] = []
diff --git a/tests/agent/test_tool_hint.py b/tests/agent/test_tool_hint.py
index f0c324083..caaa37ee3 100644
--- a/tests/agent/test_tool_hint.py
+++ b/tests/agent/test_tool_hint.py
@@ -4,7 +4,7 @@ from nanobot.agent.loop import AgentLoop
 from nanobot.providers.base import ToolCallRequest
 
 
-def _tc(name: str, args: dict) -> ToolCallRequest:
+def _tc(name: str, args) -> ToolCallRequest:
     return ToolCallRequest(id="c1", name=name, arguments=args)
 
 
@@ -147,3 +147,51 @@ class TestToolHintMultipleCalls:
         assert 'grep "TODO"' in result
         assert "read main.py" in result
         assert ", " in result
+
+
+class TestToolHintEdgeCases:
+    """Test edge cases and defensive handling (G1, G2)."""
+
+    def test_known_tool_empty_list_args(self):
+        """C1/G1: Empty list arguments should not crash."""
+        result = AgentLoop._tool_hint([_tc("read_file", [])])
+        assert result == "read_file"
+
+    def test_known_tool_none_args(self):
+        """G2: None arguments should not crash."""
+        result = AgentLoop._tool_hint([_tc("read_file", None)])
+        assert result == "read_file"
+
+    def test_fallback_empty_list_args(self):
+        """C1: Empty list args in fallback should not crash."""
+        result = AgentLoop._tool_hint([_tc("custom_tool", [])])
+        assert result == "custom_tool"
+
+    def test_fallback_none_args(self):
+        """G2: None args in fallback should not crash."""
+        result = AgentLoop._tool_hint([_tc("custom_tool", None)])
+        assert result == "custom_tool"
+
+    def test_list_dir_registered(self):
+        """S2: list_dir should use 'ls' format."""
+        result = AgentLoop._tool_hint([_tc("list_dir", {"path": "/tmp"})])
+        assert result == "ls /tmp"
+
+
+class TestToolHintMixedFolding:
+    """G4: Mixed folding groups with interleaved same-tool segments."""
+
+    def test_read_read_grep_grep_read(self):
+        """read×2, grep×2, read — should produce two separate groups."""
+        calls = [
+            _tc("read_file", {"path": "a.py"}),
+            _tc("read_file", {"path": "b.py"}),
+            _tc("grep", {"pattern": "x"}),
+            _tc("grep", {"pattern": "y"}),
+            _tc("read_file", {"path": "c.py"}),
+        ]
+        result = AgentLoop._tool_hint(calls)
+        assert "\u00d7 2" in result
+        # Should have 3 groups: read×2, grep×2, read
+        parts = result.split(", ")
+        assert len(parts) == 3
diff --git a/tests/utils/test_abbreviate_path.py b/tests/utils/test_abbreviate_path.py
index d8ee0a186..573ca0a92 100644
--- a/tests/utils/test_abbreviate_path.py
+++ b/tests/utils/test_abbreviate_path.py
@@ -77,3 +77,29 @@ class TestAbbreviatePathURLs:
     def test_short_url_unchanged(self):
         url = "https://example.com/api"
         assert abbreviate_path(url) == url
+
+    def test_url_no_path_just_domain(self):
+        """G3: URL with no path should return as-is if short enough."""
+        url = "https://example.com"
+        assert abbreviate_path(url) == url
+
+    def test_url_with_query_string(self):
+        """G3: URL with query params should abbreviate path part."""
+        url = "https://example.com/api/v2/endpoint?key=value&other=123"
+        result = abbreviate_path(url, max_len=40)
+        assert "example.com" in result
+        assert "\u2026" in result
+
+    def test_url_very_long_basename(self):
+        """G3: URL with very long basename should truncate basename."""
+        url = "https://example.com/path/very_long_resource_name_file.json"
+        result = abbreviate_path(url, max_len=35)
+        assert "example.com" in result
+        assert "\u2026" in result
+
+    def test_url_negative_budget_consistent_format(self):
+        """I3: Negative budget should still produce domain/…/basename format."""
+        url = "https://a.co/very/deep/path/with/lots/of/segments/and/a/long/basename.txt"
+        result = abbreviate_path(url, max_len=20)
+        assert "a.co" in result
+        assert "/\u2026/" in result