mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-08 04:03:38 +00:00
feat(agent): prompt behavior directives, tool descriptions, and loop robustness
This commit is contained in:
parent
ef0284a4e0
commit
edb821e10d
@ -27,9 +27,13 @@ class ContextBuilder:
|
||||
self.memory = MemoryStore(workspace)
|
||||
self.skills = SkillsLoader(workspace)
|
||||
|
||||
def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
|
||||
def build_system_prompt(
|
||||
self,
|
||||
skill_names: list[str] | None = None,
|
||||
channel: str | None = None,
|
||||
) -> str:
|
||||
"""Build the system prompt from identity, bootstrap files, memory, and skills."""
|
||||
parts = [self._get_identity()]
|
||||
parts = [self._get_identity(channel=channel)]
|
||||
|
||||
bootstrap = self._load_bootstrap_files()
|
||||
if bootstrap:
|
||||
@ -58,7 +62,7 @@ class ContextBuilder:
|
||||
|
||||
return "\n\n---\n\n".join(parts)
|
||||
|
||||
def _get_identity(self) -> str:
|
||||
def _get_identity(self, channel: str | None = None) -> str:
|
||||
"""Get the core identity section."""
|
||||
workspace_path = str(self.workspace.expanduser().resolve())
|
||||
system = platform.system()
|
||||
@ -69,6 +73,7 @@ class ContextBuilder:
|
||||
workspace_path=workspace_path,
|
||||
runtime=runtime,
|
||||
platform_policy=render_template("agent/platform_policy.md", system=system),
|
||||
channel=channel or "",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -128,7 +133,7 @@ class ContextBuilder:
|
||||
else:
|
||||
merged = [{"type": "text", "text": runtime_ctx}] + user_content
|
||||
messages = [
|
||||
{"role": "system", "content": self.build_system_prompt(skill_names)},
|
||||
{"role": "system", "content": self.build_system_prompt(skill_names, channel=channel)},
|
||||
*history,
|
||||
]
|
||||
if messages[-1].get("role") == current_role:
|
||||
|
||||
@ -24,6 +24,7 @@ from nanobot.utils.helpers import (
|
||||
from nanobot.utils.runtime import (
|
||||
EMPTY_FINAL_RESPONSE_MESSAGE,
|
||||
build_finalization_retry_message,
|
||||
build_length_recovery_message,
|
||||
ensure_nonempty_tool_result,
|
||||
is_blank_text,
|
||||
repeated_external_lookup_error,
|
||||
@ -31,7 +32,15 @@ from nanobot.utils.runtime import (
|
||||
|
||||
# User-facing fallback text for a failed model call.
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."

# Retry ceilings for the agent loop.
_MAX_EMPTY_RETRIES = 2
# Upper bound on automatic "continue" turns after finish_reason == "length".
_MAX_LENGTH_RECOVERIES = 3

# NOTE(review): presumably headroom reserved when snipping history — confirm at the use site.
_SNIP_SAFETY_BUFFER = 1024

# Microcompaction: the newest N compactable tool results are left untouched,
# and older ones are only replaced when at least MIN_CHARS long.
_MICROCOMPACT_KEEP_RECENT = 10
_MICROCOMPACT_MIN_CHARS = 500
# Tool names whose stale results may be replaced by a one-line placeholder.
_COMPACTABLE_TOOLS = frozenset(
    (
        "read_file",
        "exec",
        "grep",
        "glob",
        "web_search",
        "web_fetch",
        "list_dir",
    )
)

# Content of the synthetic tool message inserted for a tool call with no result.
_BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]"
|
||||
@dataclass(slots=True)
|
||||
class AgentRunSpec:
|
||||
"""Configuration for a single agent execution."""
|
||||
@ -88,9 +97,12 @@ class AgentRunner:
|
||||
tool_events: list[dict[str, str]] = []
|
||||
external_lookup_counts: dict[str, int] = {}
|
||||
empty_content_retries = 0
|
||||
length_recovery_count = 0
|
||||
|
||||
for iteration in range(spec.max_iterations):
|
||||
try:
|
||||
messages = self._backfill_missing_tool_results(messages)
|
||||
messages = self._microcompact(messages)
|
||||
messages = self._apply_tool_result_budget(spec, messages)
|
||||
messages_for_model = self._snip_history(spec, messages)
|
||||
except Exception as exc:
|
||||
@ -181,6 +193,7 @@ class AgentRunner:
|
||||
},
|
||||
)
|
||||
empty_content_retries = 0
|
||||
length_recovery_count = 0
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
@ -216,6 +229,27 @@ class AgentRunner:
|
||||
context.tool_calls = list(response.tool_calls)
|
||||
clean = hook.finalize_content(context, response.content)
|
||||
|
||||
if response.finish_reason == "length" and not is_blank_text(clean):
|
||||
length_recovery_count += 1
|
||||
if length_recovery_count <= _MAX_LENGTH_RECOVERIES:
|
||||
logger.info(
|
||||
"Output truncated on turn {} for {} ({}/{}); continuing",
|
||||
iteration,
|
||||
spec.session_key or "default",
|
||||
length_recovery_count,
|
||||
_MAX_LENGTH_RECOVERIES,
|
||||
)
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=True)
|
||||
messages.append(build_assistant_message(
|
||||
clean,
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
))
|
||||
messages.append(build_length_recovery_message())
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=False)
|
||||
|
||||
@ -515,6 +549,73 @@ class AgentRunner:
|
||||
return truncate_text(content, spec.max_tool_result_chars)
|
||||
return content
|
||||
|
||||
@staticmethod
def _backfill_missing_tool_results(
    messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Give every unanswered assistant tool call a synthetic tool result.

    A tool call counts as answered when any ``tool`` message in the list
    carries its id as ``tool_call_id``. For each unanswered call a
    placeholder ``tool`` message (content ``_BACKFILL_CONTENT``) is inserted
    after the declaring assistant message, behind any ``tool`` messages that
    already follow it.

    Returns the *same* list object when nothing is missing; otherwise a
    shallow copy with the placeholders inserted (the input is not mutated).
    """
    # Ids that already have a result somewhere in the conversation.
    answered_ids = {
        str(entry["tool_call_id"])
        for entry in messages
        if entry.get("role") == "tool" and entry.get("tool_call_id")
    }

    # (assistant position, call id, tool name) for every unanswered call,
    # in conversation order.
    orphans: list[tuple[int, str, str]] = []
    for position, entry in enumerate(messages):
        if entry.get("role") != "assistant":
            continue
        for call in entry.get("tool_calls") or []:
            if not isinstance(call, dict) or not call.get("id"):
                continue
            call_id = str(call["id"])
            if call_id in answered_ids:
                continue
            function = call.get("function")
            tool_name = function.get("name", "") if isinstance(function, dict) else ""
            orphans.append((position, call_id, tool_name))

    if not orphans:
        return messages

    patched = list(messages)
    # Each earlier insertion shifts later positions right by one, so the
    # running count of insertions doubles as the index correction.
    for already_inserted, (position, call_id, tool_name) in enumerate(orphans):
        slot = position + 1 + already_inserted
        while slot < len(patched) and patched[slot].get("role") == "tool":
            slot += 1
        patched.insert(
            slot,
            {
                "role": "tool",
                "tool_call_id": call_id,
                "name": tool_name,
                "content": _BACKFILL_CONTENT,
            },
        )
    return patched
|
||||
|
||||
@staticmethod
def _microcompact(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Swap the content of old, large tool results for a one-line placeholder.

    Only ``tool`` messages whose ``name`` is in ``_COMPACTABLE_TOOLS`` are
    candidates. The newest ``_MICROCOMPACT_KEEP_RECENT`` candidates are
    always kept; older ones are replaced only when their content is a string
    of at least ``_MICROCOMPACT_MIN_CHARS`` characters.

    Returns the *same* list object when nothing was replaced; otherwise a
    new list of shallow-copied messages (the input is not mutated).
    """
    candidates = [
        position
        for position, entry in enumerate(messages)
        if entry.get("role") == "tool" and entry.get("name") in _COMPACTABLE_TOOLS
    ]

    surplus = len(candidates) - _MICROCOMPACT_KEEP_RECENT
    if surplus <= 0:
        return messages

    compacted: list[dict[str, Any]] | None = None
    for position in candidates[:surplus]:
        source = messages[position]
        body = source.get("content")
        if isinstance(body, str) and len(body) >= _MICROCOMPACT_MIN_CHARS:
            # Copy lazily: callers rely on identity when nothing changes.
            if compacted is None:
                compacted = [dict(m) for m in messages]
            tool_name = source.get("name", "tool")
            compacted[position]["content"] = f"[{tool_name} result omitted from context]"

    return messages if compacted is None else compacted
|
||||
|
||||
def _apply_tool_result_budget(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
|
||||
@ -89,8 +89,10 @@ class ReadFileTool(_FsTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Read the contents of a file. Returns numbered lines. "
|
||||
"Use offset and limit to paginate through large files."
|
||||
"Read a text file. Output format: LINE_NUM|CONTENT. "
|
||||
"Use offset and limit for large files. "
|
||||
"Cannot read binary files or images. "
|
||||
"Reads exceeding ~128K chars are truncated."
|
||||
)
|
||||
|
||||
@property
|
||||
@ -175,7 +177,11 @@ class WriteFileTool(_FsTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Write content to a file at the given path. Creates parent directories if needed."
|
||||
return (
|
||||
"Write content to a file. Overwrites if the file already exists; "
|
||||
"creates parent directories as needed. "
|
||||
"For partial edits, prefer edit_file instead."
|
||||
)
|
||||
|
||||
async def execute(self, path: str | None = None, content: str | None = None, **kwargs: Any) -> str:
|
||||
try:
|
||||
@ -243,8 +249,9 @@ class EditFileTool(_FsTool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit a file by replacing old_text with new_text. "
|
||||
"Supports minor whitespace/line-ending differences. "
|
||||
"Set replace_all=true to replace every occurrence."
|
||||
"Tolerates minor whitespace/indentation differences. "
|
||||
"If old_text matches multiple times, you must provide more context "
|
||||
"or set replace_all=true. Shows a diff of the closest match on failure."
|
||||
)
|
||||
|
||||
async def execute(
|
||||
|
||||
@ -142,8 +142,9 @@ class GlobTool(_SearchTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Find files matching a glob pattern. "
|
||||
"Simple patterns like '*.py' match by filename recursively."
|
||||
"Find files matching a glob pattern (e.g. '*.py', 'tests/**/test_*.py'). "
|
||||
"Results are sorted by modification time (newest first). "
|
||||
"Skips .git, node_modules, __pycache__, and other noise directories."
|
||||
)
|
||||
|
||||
@property
|
||||
@ -261,9 +262,10 @@ class GrepTool(_SearchTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search file contents with a regex-like pattern. "
|
||||
"Supports optional glob filtering, structured output modes, "
|
||||
"type filters, pagination, and surrounding context lines."
|
||||
"Search file contents with a regex pattern. "
|
||||
"Default output_mode is files_with_matches (file paths only); "
|
||||
"use content mode for matching lines with context. "
|
||||
"Skips binary and files >2 MB. Supports glob/type filtering."
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@ -74,7 +74,13 @@ class ExecTool(Tool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Execute a shell command and return its output. Use with caution."
|
||||
return (
|
||||
"Execute a shell command and return its output. "
|
||||
"Prefer read_file/write_file/edit_file over cat/echo/sed, "
|
||||
"and grep/glob over shell find/grep. "
|
||||
"Use -y or --yes flags to avoid interactive prompts. "
|
||||
"Output is truncated at 10 000 chars; timeout defaults to 60s."
|
||||
)
|
||||
|
||||
@property
|
||||
def exclusive(self) -> bool:
|
||||
|
||||
@ -84,7 +84,11 @@ class WebSearchTool(Tool):
|
||||
"""Search the web using configured provider."""
|
||||
|
||||
name = "web_search"
|
||||
description = "Search the web. Returns titles, URLs, and snippets."
|
||||
description = (
|
||||
"Search the web. Returns titles, URLs, and snippets. "
|
||||
"count defaults to 5 (max 10). "
|
||||
"Use web_fetch to read a specific page in full."
|
||||
)
|
||||
|
||||
def __init__(self, config: WebSearchConfig | None = None, proxy: str | None = None):
|
||||
from nanobot.config.schema import WebSearchConfig
|
||||
@ -239,7 +243,11 @@ class WebFetchTool(Tool):
|
||||
"""Fetch and extract content from a URL."""
|
||||
|
||||
name = "web_fetch"
|
||||
description = "Fetch URL and extract readable content (HTML → markdown/text)."
|
||||
description = (
|
||||
"Fetch a URL and extract readable content (HTML → markdown/text). "
|
||||
"Output is capped at maxChars (default 50 000). "
|
||||
"Works for most web pages and docs; may fail on login-walled or JS-heavy sites."
|
||||
)
|
||||
|
||||
def __init__(self, max_chars: int = 50000, proxy: str | None = None):
|
||||
self.max_chars = max_chars
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
# Agent Instructions
|
||||
|
||||
You are a helpful AI assistant. Be concise, accurate, and friendly.
|
||||
|
||||
## Scheduled Reminders
|
||||
|
||||
Before scheduling reminders, check available skills and follow skill guidance first.
|
||||
|
||||
@ -2,20 +2,7 @@
|
||||
|
||||
I am nanobot 🐈, a personal AI assistant.
|
||||
|
||||
## Personality
|
||||
|
||||
- Helpful and friendly
|
||||
- Concise and to the point
|
||||
- Curious and eager to learn
|
||||
|
||||
## Values
|
||||
|
||||
- Accuracy over speed
|
||||
- User privacy and safety
|
||||
- Transparency in actions
|
||||
|
||||
## Communication Style
|
||||
|
||||
- Be clear and direct
|
||||
- Explain reasoning when helpful
|
||||
- Ask clarifying questions when needed
|
||||
I solve problems by doing, not by describing what I would do.
|
||||
I keep responses short unless depth is asked for.
|
||||
I say what I know, flag what I don't, and never fake confidence.
|
||||
I treat the user's time as the scarcest resource.
|
||||
|
||||
@ -12,15 +12,32 @@ Your workspace is at: {{ workspace_path }}
|
||||
- Custom skills: {{ workspace_path }}/skills/{% raw %}{skill-name}{% endraw %}/SKILL.md
|
||||
|
||||
{{ platform_policy }}
|
||||
{% if channel == 'telegram' or channel == 'qq' or channel == 'discord' %}
|
||||
## Format Hint
|
||||
This conversation is on a messaging app. Use short paragraphs. Avoid large headings (#, ##). Use **bold** sparingly. No tables — use plain lists.
|
||||
{% elif channel == 'whatsapp' or channel == 'sms' %}
|
||||
## Format Hint
|
||||
This conversation is on a text messaging platform that does not render markdown. Use plain text only.
|
||||
{% elif channel == 'email' %}
|
||||
## Format Hint
|
||||
This conversation is via email. Structure with clear sections. Markdown may not render — keep formatting simple.
|
||||
{% elif channel == 'cli' or channel == 'mochat' %}
|
||||
## Format Hint
|
||||
Output is rendered in a terminal. Avoid markdown headings and tables. Use plain text with minimal formatting.
|
||||
{% endif %}
|
||||
|
||||
## nanobot Guidelines
|
||||
- State intent before tool calls, but NEVER predict or claim results before receiving them.
|
||||
- Before modifying a file, read it first. Do not assume files or directories exist.
|
||||
- After writing or editing a file, re-read it if accuracy matters.
|
||||
- If a tool call fails, analyze the error before retrying with a different approach.
|
||||
- Ask for clarification when the request is ambiguous.
|
||||
- Prefer built-in `grep` / `glob` tools for workspace search before falling back to `exec`.
|
||||
- On broad searches, use `grep(output_mode="count")` or `grep(output_mode="files_with_matches")` to scope the result set before requesting full content.
|
||||
## Execution Rules
|
||||
|
||||
- Act, don't narrate. If you can do it with a tool, do it now — never end a turn with just a plan or promise.
|
||||
- Read before you write. Do not assume a file exists or contains what you expect.
|
||||
- If a tool call fails, diagnose the error and retry with a different approach before reporting failure.
|
||||
- When information is missing, look it up with tools first. Only ask the user when tools cannot answer.
|
||||
- After multi-step changes, verify the result (re-read the file, run the test, check the output).
|
||||
|
||||
## Search & Discovery
|
||||
|
||||
- Prefer built-in `grep` / `glob` over `exec` for workspace search.
|
||||
- On broad searches, use `grep(output_mode="count")` to scope before requesting full content.
|
||||
{% include 'agent/_snippets/untrusted_content.md' %}
|
||||
|
||||
Reply directly with text for conversations. Only use the 'message' tool to send to a specific chat channel.
|
||||
|
||||
@ -19,6 +19,11 @@ FINALIZATION_RETRY_PROMPT = (
|
||||
"Please provide your response to the user based on the conversation above."
|
||||
)
|
||||
|
||||
# Instruction text asking the model to resume an answer that hit the output limit.
LENGTH_RECOVERY_PROMPT = (
    "Output limit reached. "
    "Continue exactly where you left off — no recap, no apology. "
    "Break remaining work into smaller steps if needed."
)
|
||||
|
||||
|
||||
def empty_tool_result_message(tool_name: str) -> str:
|
||||
"""Short prompt-safe marker for tools that completed without visible output."""
|
||||
@ -50,6 +55,11 @@ def build_finalization_retry_message() -> dict[str, str]:
|
||||
return {"role": "user", "content": FINALIZATION_RETRY_PROMPT}
|
||||
|
||||
|
||||
def build_length_recovery_message() -> dict[str, str]:
    """Return the user-role message that asks the model to resume truncated output."""
    return dict(role="user", content=LENGTH_RECOVERY_PROMPT)
|
||||
|
||||
|
||||
def external_lookup_signature(tool_name: str, arguments: dict[str, Any]) -> str | None:
|
||||
"""Stable signature for repeated external lookups we want to throttle."""
|
||||
if tool_name == "web_fetch":
|
||||
|
||||
@ -148,6 +148,63 @@ def test_partial_dream_processing_shows_only_remainder(tmp_path) -> None:
|
||||
assert "recent question about K8s" in prompt
|
||||
|
||||
|
||||
def test_execution_rules_in_system_prompt(tmp_path) -> None:
    """The default system prompt must contain the key execution-rule phrases."""
    system_prompt = ContextBuilder(_make_workspace(tmp_path)).build_system_prompt()

    for phrase in ("Act, don't narrate", "Read before you write", "verify the result"):
        assert phrase in system_prompt
|
||||
|
||||
|
||||
def test_channel_format_hint_telegram(tmp_path) -> None:
    """A telegram channel yields the messaging-app format hint."""
    context_builder = ContextBuilder(_make_workspace(tmp_path))

    system_prompt = context_builder.build_system_prompt(channel="telegram")

    for phrase in ("Format Hint", "messaging app"):
        assert phrase in system_prompt
|
||||
|
||||
|
||||
def test_channel_format_hint_whatsapp(tmp_path) -> None:
    """A whatsapp channel yields the plain-text format hint."""
    context_builder = ContextBuilder(_make_workspace(tmp_path))

    system_prompt = context_builder.build_system_prompt(channel="whatsapp")

    for phrase in ("Format Hint", "plain text only"):
        assert phrase in system_prompt
|
||||
|
||||
|
||||
def test_channel_format_hint_absent_for_unknown(tmp_path) -> None:
    """Neither a missing channel nor an unlisted one adds a format hint."""
    context_builder = ContextBuilder(_make_workspace(tmp_path))

    # None first, then a channel name the template has no branch for.
    for channel_name in (None, "feishu"):
        system_prompt = context_builder.build_system_prompt(channel=channel_name)
        assert "Format Hint" not in system_prompt
|
||||
|
||||
|
||||
def test_build_messages_passes_channel_to_system_prompt(tmp_path) -> None:
    """The channel given to build_messages must reach the system prompt."""
    context_builder = ContextBuilder(_make_workspace(tmp_path))

    built = context_builder.build_messages(
        history=[],
        current_message="hi",
        channel="telegram",
        chat_id="123",
    )

    system_prompt = built[0]["content"]
    for phrase in ("Format Hint", "messaging app"):
        assert phrase in system_prompt
|
||||
|
||||
|
||||
def test_subagent_result_does_not_create_consecutive_assistant_messages(tmp_path) -> None:
|
||||
workspace = _make_workspace(tmp_path)
|
||||
builder = ContextBuilder(workspace)
|
||||
|
||||
@ -999,3 +999,256 @@ async def test_runner_passes_cached_tokens_to_hook_context():
|
||||
|
||||
assert len(captured_usage) == 1
|
||||
assert captured_usage[0]["cached_tokens"] == 150
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Length recovery (auto-continue on finish_reason == "length")
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_length_recovery_continues_from_truncated_output():
    """A 'length' finish makes the runner append a continuation prompt and
    call the model again until it finishes with a normal stop."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    calls_made = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal calls_made
        calls_made += 1
        # First two replies are truncated; the third completes normally.
        if calls_made > 2:
            return LLMResponse(content="final", finish_reason="stop", usage={})
        return LLMResponse(
            content=f"part{calls_made} ",
            finish_reason="length",
            usage={},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "write a long essay"}],
        tools=tools,
        model="test-model",
        max_iterations=10,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.stop_reason == "completed"
    assert result.final_content == "final"
    assert calls_made == 3
    user_turns = [m["role"] for m in result.messages if m["role"] == "user"]
    assert len(user_turns) >= 3  # original + 2 recovery prompts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_length_recovery_streaming_calls_on_stream_end_with_resuming():
    """With a streaming hook, a truncated turn ends its stream with
    resuming=True and the final turn ends it with resuming=False."""
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    calls_made = 0
    resuming_flags: list[bool] = []

    class StreamHook(AgentHook):
        def wants_streaming(self) -> bool:
            return True

        async def on_stream(self, context: AgentHookContext, delta: str) -> None:
            pass

        async def on_stream_end(self, context: AgentHookContext, resuming: bool = False) -> None:
            resuming_flags.append(resuming)

    async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs):
        nonlocal calls_made
        calls_made += 1
        # Only the very first reply is truncated.
        if calls_made != 1:
            return LLMResponse(content="done", finish_reason="stop", usage={})
        return LLMResponse(content="partial ", finish_reason="length", usage={})

    provider = MagicMock()
    provider.chat_stream_with_retry = chat_stream_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "go"}],
        tools=tools,
        model="test-model",
        max_iterations=10,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=StreamHook(),
    )
    await AgentRunner(provider).run(spec)

    assert len(resuming_flags) == 2
    assert resuming_flags[0] is True  # length recovery: resuming
    assert resuming_flags[1] is False  # final response: done
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_length_recovery_gives_up_after_max_retries():
    """A model that always truncates is called _MAX_LENGTH_RECOVERIES + 1
    times, and the run still yields some final content."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_LENGTH_RECOVERIES

    calls_made = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal calls_made
        calls_made += 1
        return LLMResponse(
            content=f"chunk{calls_made}",
            finish_reason="length",
            usage={},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "go"}],
        tools=tools,
        model="test-model",
        max_iterations=20,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    assert calls_made == _MAX_LENGTH_RECOVERIES + 1
    assert result.final_content is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backfill missing tool_results
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_backfill_missing_tool_results_inserts_error():
    """A tool call with no matching tool message gets a synthetic result."""
    from nanobot.agent.runner import AgentRunner, _BACKFILL_CONTENT

    def tool_call(call_id: str, tool_name: str) -> dict:
        return {
            "id": call_id,
            "type": "function",
            "function": {"name": tool_name, "arguments": "{}"},
        }

    # call_a has a result; call_b is orphaned.
    conversation = [
        {"role": "user", "content": "hi"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [tool_call("call_a", "exec"), tool_call("call_b", "read_file")],
        },
        {"role": "tool", "tool_call_id": "call_a", "name": "exec", "content": "ok"},
    ]

    repaired = AgentRunner._backfill_missing_tool_results(conversation)

    tool_results = [m for m in repaired if m.get("role") == "tool"]
    assert len(tool_results) == 2
    synthetic = [m for m in tool_results if m.get("tool_call_id") == "call_b"]
    assert len(synthetic) == 1
    assert synthetic[0]["content"] == _BACKFILL_CONTENT
    assert synthetic[0]["name"] == "read_file"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_backfill_noop_when_complete():
    """When every tool call has a result, the input list is returned as-is."""
    from nanobot.agent.runner import AgentRunner

    answered_call = {
        "id": "call_x",
        "type": "function",
        "function": {"name": "exec", "arguments": "{}"},
    }
    conversation = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "", "tool_calls": [answered_call]},
        {"role": "tool", "tool_call_id": "call_x", "name": "exec", "content": "done"},
        {"role": "assistant", "content": "all good"},
    ]

    repaired = AgentRunner._backfill_missing_tool_results(conversation)

    assert repaired is conversation  # same object — no copy
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Microcompact (stale tool result compaction)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_microcompact_replaces_old_tool_results():
    """Long results older than the newest _MICROCOMPACT_KEEP_RECENT are summarized."""
    from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT

    extra = 5
    total = _MICROCOMPACT_KEEP_RECENT + extra
    long_content = "x" * 600

    history: list[dict] = [{"role": "system", "content": "sys"}]
    for n in range(total):
        call_id = f"c{n}"
        history.append({
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {"id": call_id, "type": "function", "function": {"name": "read_file", "arguments": "{}"}},
            ],
        })
        history.append({
            "role": "tool",
            "tool_call_id": call_id,
            "name": "read_file",
            "content": long_content,
        })

    compacted_history = AgentRunner._microcompact(history)

    tool_results = [m for m in compacted_history if m.get("role") == "tool"]
    summarized = [m for m in tool_results if "omitted from context" in str(m.get("content", ""))]
    untouched = [m for m in tool_results if m.get("content") == long_content]
    assert len(summarized) == total - _MICROCOMPACT_KEEP_RECENT
    assert len(untouched) == _MICROCOMPACT_KEEP_RECENT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_microcompact_preserves_short_results():
    """Results shorter than _MICROCOMPACT_MIN_CHARS are never replaced."""
    from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT

    total = _MICROCOMPACT_KEEP_RECENT + 5

    history: list[dict] = []
    for n in range(total):
        call_id = f"c{n}"
        history.append({
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {"id": call_id, "type": "function", "function": {"name": "exec", "arguments": "{}"}},
            ],
        })
        history.append({
            "role": "tool",
            "tool_call_id": call_id,
            "name": "exec",
            "content": "short",
        })

    compacted_history = AgentRunner._microcompact(history)

    assert compacted_history is history  # no copy needed — all stale results are short
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_microcompact_skips_non_compactable_tools():
    """Results from tools outside the compactable set (e.g. 'message') are kept."""
    from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT

    total = _MICROCOMPACT_KEEP_RECENT + 5
    long_content = "y" * 1000

    history: list[dict] = []
    for n in range(total):
        call_id = f"c{n}"
        history.append({
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {"id": call_id, "type": "function", "function": {"name": "message", "arguments": "{}"}},
            ],
        })
        history.append({
            "role": "tool",
            "tool_call_id": call_id,
            "name": "message",
            "content": long_content,
        })

    compacted_history = AgentRunner._microcompact(history)

    assert compacted_history is history  # no compactable tools found
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user