mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-14 23:19:55 +00:00
feat(agent): enhance message injection handling and content merging
This commit is contained in:
parent
f6c39ec946
commit
cf8381f517
@ -17,7 +17,7 @@ from nanobot.agent.autocompact import AutoCompact
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.memory import Consolidator, Dream
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
@ -370,16 +370,30 @@ class AgentLoop:
|
||||
return
|
||||
self._set_runtime_checkpoint(session, payload)
|
||||
|
||||
async def _drain_pending(*, limit: int = _MAX_INJECTIONS_PER_TURN) -> list[dict[str, Any]]:
    """Take queued follow-ups without blocking and shape them as user messages.

    At most ``limit`` messages are dequeued per call; whatever is still in
    ``pending_queue`` after that is left there for a later call.  Each
    dequeued message becomes ``{"role": "user", "content": ...}`` with the
    runtime context placed ahead of the user content: joined by a blank
    line when the content is a string, prepended as a text block otherwise.
    """
    drained: list[dict[str, Any]] = []
    if pending_queue is None:
        return drained

    while len(drained) < limit:
        try:
            follow_up = pending_queue.get_nowait()
        except asyncio.QueueEmpty:
            # Queue exhausted before reaching the limit.
            break

        body = self.context._build_user_content(
            follow_up.content,
            follow_up.media or None,
        )
        preamble = self.context._build_runtime_context(
            follow_up.channel,
            follow_up.chat_id,
            self.context.timezone,
        )
        content: str | list[dict[str, Any]] = (
            f"{preamble}\n\n{body}"
            if isinstance(body, str)
            else [{"type": "text", "text": preamble}] + body
        )
        drained.append({"role": "user", "content": content})

    return drained
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
@ -451,7 +465,7 @@ class AgentLoop:
|
||||
self._pending_queues[effective_key].put_nowait(pending_msg)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
"Pending queue full for session {}, dropping follow-up",
|
||||
"Pending queue full for session {}, falling back to queued task",
|
||||
effective_key,
|
||||
)
|
||||
else:
|
||||
@ -459,7 +473,7 @@ class AgentLoop:
|
||||
"Routed follow-up message to pending queue for session {}",
|
||||
effective_key,
|
||||
)
|
||||
continue
|
||||
continue
|
||||
# Compute the effective session key before dispatching
|
||||
# This ensures /stop command can find tasks correctly when unified session is enabled
|
||||
task = asyncio.create_task(self._dispatch(msg))
|
||||
@ -697,12 +711,14 @@ class AgentLoop:
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||
|
||||
# When follow-up messages were injected mid-turn, the LLM's final
|
||||
# response addresses those follow-ups. Always send the response in
|
||||
# this case, even if MessageTool was used earlier in the turn — the
|
||||
# follow-up response is new content the user hasn't seen.
|
||||
if not had_injections:
|
||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||
# When follow-up messages were injected mid-turn, a later natural
|
||||
# language reply may address those follow-ups and should not be
|
||||
# suppressed just because MessageTool was used earlier in the turn.
|
||||
# However, if the turn falls back to the empty-final-response
|
||||
# placeholder, suppress it when the real user-visible output already
|
||||
# came from MessageTool.
|
||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||
if not had_injections or stop_reason == "empty_final_response":
|
||||
return None
|
||||
|
||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||
|
||||
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@ -94,37 +95,89 @@ class AgentRunner:
|
||||
def __init__(self, provider: LLMProvider):
|
||||
self.provider = provider
|
||||
|
||||
async def _drain_injections(self, spec: AgentRunSpec) -> list[str]:
|
||||
@staticmethod
|
||||
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
||||
if isinstance(left, str) and isinstance(right, str):
|
||||
return f"{left}\n\n{right}" if left else right
|
||||
|
||||
def _to_blocks(value: Any) -> list[dict[str, Any]]:
|
||||
if isinstance(value, list):
|
||||
return [
|
||||
item if isinstance(item, dict) else {"type": "text", "text": str(item)}
|
||||
for item in value
|
||||
]
|
||||
if value is None:
|
||||
return []
|
||||
return [{"type": "text", "text": str(value)}]
|
||||
|
||||
return _to_blocks(left) + _to_blocks(right)
|
||||
|
||||
@classmethod
def _append_injected_messages(
    cls,
    messages: list[dict[str, Any]],
    injections: list[dict[str, Any]],
) -> None:
    """Extend ``messages`` in place with ``injections``, folding adjacent user turns.

    An injected user message that would directly follow another user
    message is merged into it through ``_merge_message_content`` instead of
    being appended.  The previous entry is replaced by a merged copy, not
    mutated.  Every other injection is appended unchanged.
    """
    for incoming in injections:
        follows_user_turn = (
            bool(messages)
            and incoming.get("role") == "user"
            and messages[-1].get("role") == "user"
        )
        if not follows_user_turn:
            messages.append(incoming)
            continue

        previous = messages[-1]
        messages[-1] = {
            **previous,
            "content": cls._merge_message_content(
                previous.get("content"),
                incoming.get("content"),
            ),
        }
|
||||
|
||||
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
    """Drain pending user messages via the injection callback.

    The callback is called with ``limit=_MAX_INJECTIONS_PER_TURN`` when its
    signature accepts a ``limit`` keyword (or ``**kwargs``), and with no
    arguments otherwise.  If the signature cannot be inspected, the
    callback is still invoked, without ``limit``.

    Returns normalized user messages (capped by
    ``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is
    nothing to inject.  Items that are already ``{"role": "user",
    "content": ...}`` dicts pass through unchanged; other items contribute
    their ``content`` attribute (or ``str(item)``) when it is non-blank
    text.  Messages beyond the cap are logged before being cut.
    """
    callback = spec.injection_callback
    if callback is None:
        return []

    # Probe the signature separately, so an un-inspectable callable
    # degrades to the no-argument call instead of skipping the drain.
    try:
        parameters = inspect.signature(callback).parameters
    except (TypeError, ValueError):
        accepts_limit = False
    else:
        accepts_limit = "limit" in parameters or any(
            parameter.kind is inspect.Parameter.VAR_KEYWORD
            for parameter in parameters.values()
        )

    try:
        if accepts_limit:
            items = await callback(limit=_MAX_INJECTIONS_PER_TURN)
        else:
            items = await callback()
    except Exception:
        # Best effort: a failing callback must not abort the agent turn.
        logger.exception("injection_callback failed")
        return []
    if not items:
        return []

    injected_messages: list[dict[str, Any]] = []
    for item in items:
        if isinstance(item, dict) and item.get("role") == "user" and "content" in item:
            injected_messages.append(item)
            continue
        text = getattr(item, "content", str(item))
        if not isinstance(text, str):
            # Non-text content cannot be checked for blankness; skip it
            # explicitly instead of raising on ``.strip()``.
            logger.warning(
                "Skipping injected item with non-text content of type {}",
                type(text).__name__,
            )
            continue
        if text.strip():
            injected_messages.append({"role": "user", "content": text})

    if len(injected_messages) > _MAX_INJECTIONS_PER_TURN:
        dropped = len(injected_messages) - _MAX_INJECTIONS_PER_TURN
        logger.warning(
            "Injection callback returned {} messages, capping to {} ({} dropped)",
            len(injected_messages), _MAX_INJECTIONS_PER_TURN, dropped,
        )
        # Keep the earliest messages.
        injected_messages = injected_messages[:_MAX_INJECTIONS_PER_TURN]
    return injected_messages
|
||||
|
||||
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
||||
hook = spec.hook or AgentHook()
|
||||
@ -247,8 +300,7 @@ class AgentRunner:
|
||||
if injections:
|
||||
had_injections = True
|
||||
injection_cycles += 1
|
||||
for text in injections:
|
||||
messages.append({"role": "user", "content": text})
|
||||
self._append_injected_messages(messages, injections)
|
||||
logger.info(
|
||||
"Injected {} follow-up message(s) after tool execution ({}/{})",
|
||||
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
||||
@ -340,8 +392,7 @@ class AgentRunner:
|
||||
"pending_tool_calls": [],
|
||||
},
|
||||
)
|
||||
for text in injections:
|
||||
messages.append({"role": "user", "content": text})
|
||||
self._append_injected_messages(messages, injections)
|
||||
logger.info(
|
||||
"Injected {} follow-up message(s) after final response ({}/{})",
|
||||
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
@ -1633,7 +1634,7 @@ async def test_drain_injections_returns_empty_when_no_callback():
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_injections_extracts_content_from_inbound_messages():
|
||||
"""Should extract .content from InboundMessage objects."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
provider = MagicMock()
|
||||
@ -1655,12 +1656,15 @@ async def test_drain_injections_extracts_content_from_inbound_messages():
|
||||
injection_callback=cb,
|
||||
)
|
||||
result = await runner._drain_injections(spec)
|
||||
assert result == ["hello", "world"]
|
||||
assert result == [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "user", "content": "world"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_injections_caps_at_max_and_logs_warning():
|
||||
"""When more than _MAX_INJECTIONS_PER_TURN items, only the last N are kept."""
|
||||
async def test_drain_injections_passes_limit_to_callback_when_supported():
|
||||
"""Limit-aware callbacks can preserve overflow in their own queue."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
@ -1668,14 +1672,16 @@ async def test_drain_injections_caps_at_max_and_logs_warning():
|
||||
runner = AgentRunner(provider)
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
seen_limits: list[int] = []
|
||||
|
||||
msgs = [
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}")
|
||||
for i in range(_MAX_INJECTIONS_PER_TURN + 3)
|
||||
]
|
||||
|
||||
async def cb():
|
||||
return msgs
|
||||
async def cb(*, limit: int):
|
||||
seen_limits.append(limit)
|
||||
return msgs[:limit]
|
||||
|
||||
spec = AgentRunSpec(
|
||||
initial_messages=[], tools=tools, model="m",
|
||||
@ -1683,10 +1689,12 @@ async def test_drain_injections_caps_at_max_and_logs_warning():
|
||||
injection_callback=cb,
|
||||
)
|
||||
result = await runner._drain_injections(spec)
|
||||
assert len(result) == _MAX_INJECTIONS_PER_TURN
|
||||
# Should keep the LAST _MAX_INJECTIONS_PER_TURN items
|
||||
assert result[0] == "msg3"
|
||||
assert result[-1] == f"msg{_MAX_INJECTIONS_PER_TURN + 2}"
|
||||
assert seen_limits == [_MAX_INJECTIONS_PER_TURN]
|
||||
assert result == [
|
||||
{"role": "user", "content": "msg0"},
|
||||
{"role": "user", "content": "msg1"},
|
||||
{"role": "user", "content": "msg2"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -1715,7 +1723,7 @@ async def test_drain_injections_skips_empty_content():
|
||||
injection_callback=cb,
|
||||
)
|
||||
result = await runner._drain_injections(spec)
|
||||
assert result == ["valid"]
|
||||
assert result == [{"role": "user", "content": "valid"}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -1922,6 +1930,129 @@ async def test_checkpoint2_preserves_final_response_in_history_before_followup()
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_loop_injected_followup_preserves_image_media(tmp_path):
    """An image follow-up injected mid-turn must reach the model as multimodal content."""
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    # A tiny PNG written to disk so the follow-up can reference it by path.
    png_bytes = base64.b64decode(
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+yF9kAAAAASUVORK5CYII="
    )
    image_path = tmp_path / "followup.png"
    image_path.write_bytes(png_bytes)

    # Each provider call records a snapshot of the message list it received.
    snapshots: list[list[dict]] = []

    async def chat_with_retry(*, messages, **kwargs):
        snapshots.append(list(messages))
        answer = "first answer" if len(snapshots) == 1 else "second answer"
        return LLMResponse(content=answer, tool_calls=[], usage={})

    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    provider.chat_with_retry = chat_with_retry

    loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
    loop.tools.get_definitions = MagicMock(return_value=[])

    # One image-only follow-up is waiting before the turn starts.
    follow_up = InboundMessage(
        channel="cli",
        sender_id="u",
        chat_id="c",
        content="",
        media=[str(image_path)],
    )
    pending_queue = asyncio.Queue()
    await pending_queue.put(follow_up)

    outcome = await loop._run_agent_loop(
        [{"role": "user", "content": "hello"}],
        channel="cli",
        chat_id="c",
        pending_queue=pending_queue,
    )
    final_content, _, _, _, had_injections = outcome

    assert final_content == "second answer"
    assert had_injections is True
    assert len(snapshots) == 2

    multimodal_user_turns = [
        entry
        for entry in snapshots[-1]
        if entry.get("role") == "user" and isinstance(entry.get("content"), list)
    ]
    assert multimodal_user_turns
    block_types = [
        part.get("type")
        for part in multimodal_user_turns[-1]["content"]
        if isinstance(part, dict)
    ]
    assert "image_url" in block_types
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_merges_multiple_injected_user_messages_without_losing_media():
    """Two follow-ups injected together end up as one user turn that keeps the image."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    # Deep-ish snapshots (one dict copy per message) of every provider call.
    snapshots: list[list[dict]] = []

    async def chat_with_retry(*, messages, **kwargs):
        snapshots.append([dict(entry) for entry in messages])
        answer = "first answer" if len(snapshots) == 1 else "second answer"
        return LLMResponse(content=answer, tool_calls=[], usage={})

    image_follow_up = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
            {"type": "text", "text": "look at this"},
        ],
    }
    text_follow_up = {"role": "user", "content": "and answer briefly"}

    async def inject_cb():
        # Only offer the follow-ups right after the first model call.
        if len(snapshots) != 1:
            return []
        return [image_follow_up, text_follow_up]

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "second answer"
    assert len(snapshots) == 2

    user_turns = [entry for entry in snapshots[-1] if entry.get("role") == "user"]
    assert len(user_turns) == 2

    merged_content = user_turns[-1]["content"]
    assert isinstance(merged_content, list)
    dict_blocks = [part for part in merged_content if isinstance(part, dict)]
    assert "image_url" in [part.get("type") for part in dict_blocks]
    assert ("text", "and answer briefly") in [
        (part.get("type"), part.get("text")) for part in dict_blocks
    ]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_injection_cycles_capped_at_max():
|
||||
"""Injection cycles should be capped at _MAX_INJECTION_CYCLES."""
|
||||
@ -2042,6 +2173,88 @@ async def test_followup_routed_to_pending_queue(tmp_path):
|
||||
assert queued_msg.session_key == UNIFIED_SESSION_KEY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pending_queue_preserves_overflow_for_next_injection_cycle(tmp_path):
    """Follow-ups beyond the per-turn cap stay queued and are injected on a later drain."""
    from nanobot.agent.loop import AgentLoop
    from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    # Snapshot of the message list seen by each provider call.
    snapshots: list[list[dict]] = []

    async def chat_with_retry(*, messages, **kwargs):
        snapshots.append([dict(entry) for entry in messages])
        return LLMResponse(content=f"answer-{len(snapshots)}", tool_calls=[], usage={})

    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    provider.chat_with_retry = chat_with_retry

    loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
    loop.tools.get_definitions = MagicMock(return_value=[])

    # Queue two more follow-ups than a single drain is allowed to take.
    total_followups = _MAX_INJECTIONS_PER_TURN + 2
    pending_queue = asyncio.Queue()
    for idx in range(total_followups):
        follow_up = InboundMessage(
            channel="cli",
            sender_id="u",
            chat_id="c",
            content=f"follow-up-{idx}",
        )
        await pending_queue.put(follow_up)

    outcome = await loop._run_agent_loop(
        [{"role": "user", "content": "hello"}],
        channel="cli",
        chat_id="c",
        pending_queue=pending_queue,
    )
    final_content, _, _, _, had_injections = outcome

    assert final_content == "answer-3"
    assert had_injections is True
    assert len(snapshots) == 3

    user_texts = [
        entry["content"]
        for entry in snapshots[-1]
        if entry.get("role") == "user" and isinstance(entry.get("content"), str)
    ]
    flattened_user_content = "\n".join(user_texts)
    missing = [
        f"follow-up-{idx}"
        for idx in range(total_followups)
        if f"follow-up-{idx}" not in flattened_user_content
    ]
    assert not missing
    assert pending_queue.empty()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pending_queue_full_falls_back_to_queued_task(tmp_path):
    """QueueFull should preserve the message by dispatching a queued task.

    The session's pending queue is pre-filled to capacity, a follow-up is
    published on the bus, and the run loop is expected to hand that
    follow-up to ``_dispatch`` while leaving the already-queued item alone.
    """
    from nanobot.bus.events import InboundMessage

    loop = _make_loop(tmp_path)
    loop._dispatch = AsyncMock()  # type: ignore[method-assign]

    # A one-slot queue that is already full, registered for session "cli:c".
    pending = asyncio.Queue(maxsize=1)
    pending.put_nowait(
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="already queued")
    )
    loop._pending_queues["cli:c"] = pending

    run_task = asyncio.create_task(loop.run())
    msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up")
    await loop.bus.publish_inbound(msg)

    # Poll until the dispatch happens; the monotonic clock keeps the two
    # second timeout immune to wall-clock adjustments.
    deadline = time.monotonic() + 2
    while loop._dispatch.await_count == 0 and time.monotonic() < deadline:
        await asyncio.sleep(0.01)

    loop.stop()
    await asyncio.wait_for(run_task, timeout=2)

    assert loop._dispatch.await_count == 1
    dispatched_msg = loop._dispatch.await_args.args[0]
    assert dispatched_msg.content == "follow-up"
    assert pending.qsize() == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_republishes_leftover_queue_messages(tmp_path):
|
||||
"""Messages left in the pending queue after _dispatch are re-published to the bus.
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Test message tool suppress logic for final replies."""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
@ -86,6 +87,42 @@ class TestMessageToolSuppressLogic:
|
||||
assert result is not None
|
||||
assert "Hello" in result.content
|
||||
|
||||
@pytest.mark.asyncio
async def test_injected_followup_with_message_tool_does_not_emit_empty_fallback(
    self, tmp_path: Path
) -> None:
    """With a follow-up pending, only the MessageTool output is sent and no final reply is returned.

    The scripted provider answers once, then calls the message tool, then
    returns empty responses for every remaining call.
    """
    loop = _make_loop(tmp_path)

    message_call = ToolCallRequest(
        id="call1",
        name="message",
        arguments={"content": "Tool reply", "channel": "feishu", "chat_id": "chat123"},
    )
    scripted = [
        LLMResponse(content="First answer", tool_calls=[]),
        LLMResponse(content="", tool_calls=[message_call]),
    ]
    scripted += [LLMResponse(content="", tool_calls=[]) for _ in range(3)]
    responses = iter(scripted)
    loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(responses))
    loop.tools.get_definitions = MagicMock(return_value=[])

    # Collect everything the message tool sends.
    sent: list[OutboundMessage] = []
    message_tool = loop.tools.get("message")
    if isinstance(message_tool, MessageTool):
        message_tool.set_send_callback(AsyncMock(side_effect=sent.append))

    follow_up = InboundMessage(
        channel="feishu", sender_id="user1", chat_id="chat123", content="follow-up"
    )
    pending_queue = asyncio.Queue()
    await pending_queue.put(follow_up)

    start = InboundMessage(
        channel="feishu", sender_id="user1", chat_id="chat123", content="Start"
    )
    result = await loop._process_message(start, pending_queue=pending_queue)

    assert len(sent) == 1
    assert sent[0].content == "Tool reply"
    assert result is None
|
||||
|
||||
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user