mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 08:02:30 +00:00
fix(agent): tighten ask_user CLI handling
Made-with: Cursor
This commit is contained in:
parent
3b1ea99ee1
commit
403ce23d22
@ -20,7 +20,13 @@ from nanobot.agent.memory import Consolidator, Dream
|
|||||||
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
|
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
|
||||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||||
from nanobot.agent.subagent import SubagentManager
|
from nanobot.agent.subagent import SubagentManager
|
||||||
from nanobot.agent.tools.ask import AskUserTool
|
from nanobot.agent.tools.ask import (
|
||||||
|
AskUserTool,
|
||||||
|
ask_user_options_from_messages,
|
||||||
|
ask_user_outbound,
|
||||||
|
ask_user_tool_result_messages,
|
||||||
|
pending_ask_user_id,
|
||||||
|
)
|
||||||
from nanobot.agent.tools.cron import CronTool
|
from nanobot.agent.tools.cron import CronTool
|
||||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||||
from nanobot.agent.tools.message import MessageTool
|
from nanobot.agent.tools.message import MessageTool
|
||||||
@ -54,7 +60,6 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
UNIFIED_SESSION_KEY = "unified:default"
|
UNIFIED_SESSION_KEY = "unified:default"
|
||||||
BUTTON_CHANNELS = frozenset({"telegram"})
|
|
||||||
|
|
||||||
|
|
||||||
class _LoopHook(AgentHook):
|
class _LoopHook(AgentHook):
|
||||||
@ -410,69 +415,6 @@ class AgentLoop:
|
|||||||
return UNIFIED_SESSION_KEY
|
return UNIFIED_SESSION_KEY
|
||||||
return msg.session_key
|
return msg.session_key
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _tool_call_name(tool_call: dict[str, Any]) -> str:
|
|
||||||
function = tool_call.get("function")
|
|
||||||
if isinstance(function, dict) and isinstance(function.get("name"), str):
|
|
||||||
return function["name"]
|
|
||||||
name = tool_call.get("name")
|
|
||||||
return name if isinstance(name, str) else ""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
function = tool_call.get("function")
|
|
||||||
raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments")
|
|
||||||
if isinstance(raw, dict):
|
|
||||||
return raw
|
|
||||||
if isinstance(raw, str):
|
|
||||||
try:
|
|
||||||
parsed = json.loads(raw)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return {}
|
|
||||||
return parsed if isinstance(parsed, dict) else {}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def _pending_ask_user_id(self, history: list[dict[str, Any]]) -> str | None:
|
|
||||||
pending: dict[str, str] = {}
|
|
||||||
for message in history:
|
|
||||||
if message.get("role") == "assistant":
|
|
||||||
for tool_call in message.get("tool_calls") or []:
|
|
||||||
if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str):
|
|
||||||
pending[tool_call["id"]] = self._tool_call_name(tool_call)
|
|
||||||
elif message.get("role") == "tool":
|
|
||||||
tool_call_id = message.get("tool_call_id")
|
|
||||||
if isinstance(tool_call_id, str):
|
|
||||||
pending.pop(tool_call_id, None)
|
|
||||||
for tool_call_id, name in reversed(pending.items()):
|
|
||||||
if name == "ask_user":
|
|
||||||
return tool_call_id
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _ask_user_options_from_messages(self, messages: list[dict[str, Any]]) -> list[str]:
|
|
||||||
for message in reversed(messages):
|
|
||||||
if message.get("role") != "assistant":
|
|
||||||
continue
|
|
||||||
for tool_call in reversed(message.get("tool_calls") or []):
|
|
||||||
if not isinstance(tool_call, dict) or self._tool_call_name(tool_call) != "ask_user":
|
|
||||||
continue
|
|
||||||
options = self._tool_call_arguments(tool_call).get("options")
|
|
||||||
if isinstance(options, list):
|
|
||||||
return [str(option) for option in options if isinstance(option, str)]
|
|
||||||
return []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _ask_user_outbound(
|
|
||||||
content: str | None,
|
|
||||||
options: list[str],
|
|
||||||
channel: str,
|
|
||||||
) -> tuple[str | None, list[list[str]]]:
|
|
||||||
if not options:
|
|
||||||
return content, []
|
|
||||||
if channel in BUTTON_CHANNELS:
|
|
||||||
return content, [options]
|
|
||||||
option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1))
|
|
||||||
return f"{content}\n\n{option_text}" if content else option_text, []
|
|
||||||
|
|
||||||
async def _run_agent_loop(
|
async def _run_agent_loop(
|
||||||
self,
|
self,
|
||||||
initial_messages: list[dict],
|
initial_messages: list[dict],
|
||||||
@ -874,8 +816,8 @@ class AgentLoop:
|
|||||||
self._clear_runtime_checkpoint(session)
|
self._clear_runtime_checkpoint(session)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||||
options = self._ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
|
options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
|
||||||
content, buttons = self._ask_user_outbound(
|
content, buttons = ask_user_outbound(
|
||||||
final_content or "Background task completed.",
|
final_content or "Background task completed.",
|
||||||
options,
|
options,
|
||||||
channel,
|
channel,
|
||||||
@ -923,18 +865,14 @@ class AgentLoop:
|
|||||||
|
|
||||||
history = session.get_history(max_messages=0)
|
history = session.get_history(max_messages=0)
|
||||||
|
|
||||||
pending_ask_id = self._pending_ask_user_id(history)
|
pending_ask_id = pending_ask_user_id(history)
|
||||||
if pending_ask_id:
|
if pending_ask_id:
|
||||||
initial_messages = [
|
initial_messages = ask_user_tool_result_messages(
|
||||||
{"role": "system", "content": self.context.build_system_prompt(channel=msg.channel)},
|
self.context.build_system_prompt(channel=msg.channel),
|
||||||
*history,
|
history,
|
||||||
{
|
pending_ask_id,
|
||||||
"role": "tool",
|
msg.content,
|
||||||
"tool_call_id": pending_ask_id,
|
)
|
||||||
"name": "ask_user",
|
|
||||||
"content": msg.content,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
initial_messages = self.context.build_messages(
|
initial_messages = self.context.build_messages(
|
||||||
history=history,
|
history=history,
|
||||||
@ -1030,12 +968,12 @@ class AgentLoop:
|
|||||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||||
|
|
||||||
meta = dict(msg.metadata or {})
|
meta = dict(msg.metadata or {})
|
||||||
final_content, buttons = self._ask_user_outbound(
|
final_content, buttons = ask_user_outbound(
|
||||||
final_content,
|
final_content,
|
||||||
self._ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [],
|
ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [],
|
||||||
msg.channel,
|
msg.channel,
|
||||||
)
|
)
|
||||||
if on_stream is not None and stop_reason != "error":
|
if on_stream is not None and stop_reason not in {"ask_user", "error"}:
|
||||||
meta["_streamed"] = True
|
meta["_streamed"] = True
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel,
|
channel=msg.channel,
|
||||||
|
|||||||
@ -278,17 +278,22 @@ class AgentRunner:
|
|||||||
self._accumulate_usage(usage, raw_usage)
|
self._accumulate_usage(usage, raw_usage)
|
||||||
|
|
||||||
if response.should_execute_tools:
|
if response.should_execute_tools:
|
||||||
|
tool_calls = list(response.tool_calls)
|
||||||
|
ask_index = next((i for i, tc in enumerate(tool_calls) if tc.name == "ask_user"), None)
|
||||||
|
if ask_index is not None:
|
||||||
|
tool_calls = tool_calls[: ask_index + 1]
|
||||||
|
context.tool_calls = list(tool_calls)
|
||||||
if hook.wants_streaming():
|
if hook.wants_streaming():
|
||||||
await hook.on_stream_end(context, resuming=True)
|
await hook.on_stream_end(context, resuming=True)
|
||||||
|
|
||||||
assistant_message = build_assistant_message(
|
assistant_message = build_assistant_message(
|
||||||
response.content or "",
|
response.content or "",
|
||||||
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
|
tool_calls=[tc.to_openai_tool_call() for tc in tool_calls],
|
||||||
reasoning_content=response.reasoning_content,
|
reasoning_content=response.reasoning_content,
|
||||||
thinking_blocks=response.thinking_blocks,
|
thinking_blocks=response.thinking_blocks,
|
||||||
)
|
)
|
||||||
messages.append(assistant_message)
|
messages.append(assistant_message)
|
||||||
tools_used.extend(tc.name for tc in response.tool_calls)
|
tools_used.extend(tc.name for tc in tool_calls)
|
||||||
await self._emit_checkpoint(
|
await self._emit_checkpoint(
|
||||||
spec,
|
spec,
|
||||||
{
|
{
|
||||||
@ -297,7 +302,7 @@ class AgentRunner:
|
|||||||
"model": spec.model,
|
"model": spec.model,
|
||||||
"assistant_message": assistant_message,
|
"assistant_message": assistant_message,
|
||||||
"completed_tool_results": [],
|
"completed_tool_results": [],
|
||||||
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls],
|
"pending_tool_calls": [tc.to_openai_tool_call() for tc in tool_calls],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -305,14 +310,14 @@ class AgentRunner:
|
|||||||
|
|
||||||
results, new_events, fatal_error = await self._execute_tools(
|
results, new_events, fatal_error = await self._execute_tools(
|
||||||
spec,
|
spec,
|
||||||
response.tool_calls,
|
tool_calls,
|
||||||
external_lookup_counts,
|
external_lookup_counts,
|
||||||
)
|
)
|
||||||
tool_events.extend(new_events)
|
tool_events.extend(new_events)
|
||||||
context.tool_results = list(results)
|
context.tool_results = list(results)
|
||||||
context.tool_events = list(new_events)
|
context.tool_events = list(new_events)
|
||||||
completed_tool_results: list[dict[str, Any]] = []
|
completed_tool_results: list[dict[str, Any]] = []
|
||||||
for tool_call, result in zip(response.tool_calls, results):
|
for tool_call, result in zip(tool_calls, results):
|
||||||
if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user":
|
if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user":
|
||||||
continue
|
continue
|
||||||
tool_message = {
|
tool_message = {
|
||||||
@ -700,7 +705,7 @@ class AgentRunner:
|
|||||||
tool_call: ToolCallRequest,
|
tool_call: ToolCallRequest,
|
||||||
external_lookup_counts: dict[str, int],
|
external_lookup_counts: dict[str, int],
|
||||||
) -> tuple[Any, dict[str, str], BaseException | None]:
|
) -> tuple[Any, dict[str, str], BaseException | None]:
|
||||||
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
hint = "\n\n[Analyze the error above and try a different approach.]"
|
||||||
lookup_error = repeated_external_lookup_error(
|
lookup_error = repeated_external_lookup_error(
|
||||||
tool_call.name,
|
tool_call.name,
|
||||||
tool_call.arguments,
|
tool_call.arguments,
|
||||||
@ -713,8 +718,8 @@ class AgentRunner:
|
|||||||
"detail": "repeated external lookup blocked",
|
"detail": "repeated external lookup blocked",
|
||||||
}
|
}
|
||||||
if spec.fail_on_tool_error:
|
if spec.fail_on_tool_error:
|
||||||
return lookup_error + _HINT, event, RuntimeError(lookup_error)
|
return lookup_error + hint, event, RuntimeError(lookup_error)
|
||||||
return lookup_error + _HINT, event, None
|
return lookup_error + hint, event, None
|
||||||
prepare_call = getattr(spec.tools, "prepare_call", None)
|
prepare_call = getattr(spec.tools, "prepare_call", None)
|
||||||
tool, params, prep_error = None, tool_call.arguments, None
|
tool, params, prep_error = None, tool_call.arguments, None
|
||||||
if callable(prepare_call):
|
if callable(prepare_call):
|
||||||
@ -730,7 +735,7 @@ class AgentRunner:
|
|||||||
"status": "error",
|
"status": "error",
|
||||||
"detail": prep_error.split(": ", 1)[-1][:120],
|
"detail": prep_error.split(": ", 1)[-1][:120],
|
||||||
}
|
}
|
||||||
return prep_error + _HINT, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
|
return prep_error + hint, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
|
||||||
try:
|
try:
|
||||||
if tool is not None:
|
if tool is not None:
|
||||||
result = await tool.execute(**params)
|
result = await tool.execute(**params)
|
||||||
@ -758,8 +763,8 @@ class AgentRunner:
|
|||||||
"detail": result.replace("\n", " ").strip()[:120],
|
"detail": result.replace("\n", " ").strip()[:120],
|
||||||
}
|
}
|
||||||
if spec.fail_on_tool_error:
|
if spec.fail_on_tool_error:
|
||||||
return result + _HINT, event, RuntimeError(result)
|
return result + hint, event, RuntimeError(result)
|
||||||
return result + _HINT, event, None
|
return result + hint, event, None
|
||||||
|
|
||||||
detail = "" if result is None else str(result)
|
detail = "" if result is None else str(result)
|
||||||
detail = detail.replace("\n", " ").strip()
|
detail = detail.replace("\n", " ").strip()
|
||||||
|
|||||||
@ -1,10 +1,13 @@
|
|||||||
"""Tool for pausing a turn until the user answers."""
|
"""Tool for pausing a turn until the user answers."""
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||||
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
|
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
|
||||||
|
|
||||||
|
BUTTON_CHANNELS = frozenset({"telegram"})
|
||||||
|
|
||||||
|
|
||||||
class AskUserInterrupt(BaseException):
|
class AskUserInterrupt(BaseException):
|
||||||
"""Internal signal: the runner should stop and wait for user input."""
|
"""Internal signal: the runner should stop and wait for user input."""
|
||||||
@ -48,3 +51,86 @@ class AskUserTool(Tool):
|
|||||||
|
|
||||||
async def execute(self, question: str, options: list[str] | None = None, **_: Any) -> Any:
|
async def execute(self, question: str, options: list[str] | None = None, **_: Any) -> Any:
|
||||||
raise AskUserInterrupt(question=question, options=options)
|
raise AskUserInterrupt(question=question, options=options)
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_call_name(tool_call: dict[str, Any]) -> str:
|
||||||
|
function = tool_call.get("function")
|
||||||
|
if isinstance(function, dict) and isinstance(function.get("name"), str):
|
||||||
|
return function["name"]
|
||||||
|
name = tool_call.get("name")
|
||||||
|
return name if isinstance(name, str) else ""
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
function = tool_call.get("function")
|
||||||
|
raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments")
|
||||||
|
if isinstance(raw, dict):
|
||||||
|
return raw
|
||||||
|
if isinstance(raw, str):
|
||||||
|
try:
|
||||||
|
parsed = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return {}
|
||||||
|
return parsed if isinstance(parsed, dict) else {}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def pending_ask_user_id(history: list[dict[str, Any]]) -> str | None:
|
||||||
|
pending: dict[str, str] = {}
|
||||||
|
for message in history:
|
||||||
|
if message.get("role") == "assistant":
|
||||||
|
for tool_call in message.get("tool_calls") or []:
|
||||||
|
if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str):
|
||||||
|
pending[tool_call["id"]] = _tool_call_name(tool_call)
|
||||||
|
elif message.get("role") == "tool":
|
||||||
|
tool_call_id = message.get("tool_call_id")
|
||||||
|
if isinstance(tool_call_id, str):
|
||||||
|
pending.pop(tool_call_id, None)
|
||||||
|
for tool_call_id, name in reversed(pending.items()):
|
||||||
|
if name == "ask_user":
|
||||||
|
return tool_call_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def ask_user_tool_result_messages(
|
||||||
|
system_prompt: str,
|
||||||
|
history: list[dict[str, Any]],
|
||||||
|
tool_call_id: str,
|
||||||
|
content: str,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
return [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
*history,
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool_call_id,
|
||||||
|
"name": "ask_user",
|
||||||
|
"content": content,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def ask_user_options_from_messages(messages: list[dict[str, Any]]) -> list[str]:
|
||||||
|
for message in reversed(messages):
|
||||||
|
if message.get("role") != "assistant":
|
||||||
|
continue
|
||||||
|
for tool_call in reversed(message.get("tool_calls") or []):
|
||||||
|
if not isinstance(tool_call, dict) or _tool_call_name(tool_call) != "ask_user":
|
||||||
|
continue
|
||||||
|
options = _tool_call_arguments(tool_call).get("options")
|
||||||
|
if isinstance(options, list):
|
||||||
|
return [str(option) for option in options if isinstance(option, str)]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def ask_user_outbound(
|
||||||
|
content: str | None,
|
||||||
|
options: list[str],
|
||||||
|
channel: str,
|
||||||
|
) -> tuple[str | None, list[list[str]]]:
|
||||||
|
if not options:
|
||||||
|
return content, []
|
||||||
|
if channel in BUTTON_CHANNELS:
|
||||||
|
return content, [options]
|
||||||
|
option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1))
|
||||||
|
return f"{content}\n\n{option_text}" if content else option_text, []
|
||||||
|
|||||||
@ -212,12 +212,16 @@ async def _print_interactive_response(
|
|||||||
|
|
||||||
def _print_cli_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
|
def _print_cli_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
|
||||||
"""Print a CLI progress line, pausing the spinner if needed."""
|
"""Print a CLI progress line, pausing the spinner if needed."""
|
||||||
|
if not text.strip():
|
||||||
|
return
|
||||||
with thinking.pause() if thinking else nullcontext():
|
with thinking.pause() if thinking else nullcontext():
|
||||||
console.print(f" [dim]↳ {text}[/dim]")
|
console.print(f" [dim]↳ {text}[/dim]")
|
||||||
|
|
||||||
|
|
||||||
async def _print_interactive_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
|
async def _print_interactive_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
|
||||||
"""Print an interactive progress line, pausing the spinner if needed."""
|
"""Print an interactive progress line, pausing the spinner if needed."""
|
||||||
|
if not text.strip():
|
||||||
|
return
|
||||||
with thinking.pause() if thinking else nullcontext():
|
with thinking.pause() if thinking else nullcontext():
|
||||||
await _print_interactive_line(text)
|
await _print_interactive_line(text)
|
||||||
|
|
||||||
|
|||||||
@ -15,10 +15,15 @@ from nanobot.providers.base import GenerationSettings, LLMResponse, ToolCallRequ
|
|||||||
|
|
||||||
|
|
||||||
def _make_provider(chat_with_retry):
|
def _make_provider(chat_with_retry):
|
||||||
|
async def chat_stream_with_retry(**kwargs):
|
||||||
|
kwargs.pop("on_content_delta", None)
|
||||||
|
return await chat_with_retry(**kwargs)
|
||||||
|
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
provider.generation = GenerationSettings()
|
provider.generation = GenerationSettings()
|
||||||
provider.chat_with_retry = chat_with_retry
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
provider.chat_stream_with_retry = chat_stream_with_retry
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
|
||||||
@ -88,7 +93,8 @@ async def test_runner_pauses_on_ask_user_without_executing_later_tools():
|
|||||||
assert "ask_user" in result.tools_used
|
assert "ask_user" in result.tools_used
|
||||||
assert later.called is False
|
assert later.called is False
|
||||||
assert result.messages[-1]["role"] == "assistant"
|
assert result.messages[-1]["role"] == "assistant"
|
||||||
assert result.messages[-1]["tool_calls"][0]["function"]["name"] == "ask_user"
|
tool_calls = result.messages[-1]["tool_calls"]
|
||||||
|
assert [tool_call["function"]["name"] for tool_call in tool_calls] == ["ask_user"]
|
||||||
assert not any(message.get("name") == "ask_user" for message in result.messages)
|
assert not any(message.get("name") == "ask_user" for message in result.messages)
|
||||||
|
|
||||||
|
|
||||||
@ -122,13 +128,22 @@ async def test_ask_user_text_fallback_resumes_with_next_message(tmp_path):
|
|||||||
model="test-model",
|
model="test-model",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def on_stream(delta: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def on_stream_end(**kwargs) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
first = await loop._process_message(
|
first = await loop._process_message(
|
||||||
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up")
|
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up"),
|
||||||
|
on_stream=on_stream,
|
||||||
|
on_stream_end=on_stream_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert first is not None
|
assert first is not None
|
||||||
assert first.content == "Install the optional package?\n\n1. Install\n2. Skip"
|
assert first.content == "Install the optional package?\n\n1. Install\n2. Skip"
|
||||||
assert first.buttons == []
|
assert first.buttons == []
|
||||||
|
assert "_streamed" not in first.metadata
|
||||||
|
|
||||||
session = loop.sessions.get_or_create("cli:direct")
|
session = loop.sessions.get_or_create("cli:direct")
|
||||||
assert any(message.get("role") == "assistant" and message.get("tool_calls") for message in session.messages)
|
assert any(message.get("role") == "assistant" and message.get("tool_calls") for message in session.messages)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user