mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-12 14:13:39 +00:00
feat(agent): mid-turn message injection for responsive follow-ups (#2985)
* feat(agent): add mid-turn message injection for responsive follow-ups
Allow user messages sent during an active agent turn to be injected
into the running LLM context instead of being queued behind a
per-session lock. Inspired by Claude Code's mid-turn queue drain
mechanism (query.ts:1547-1643).
Key design decisions:
- Messages are injected as natural user messages between iterations,
no tool cancellation or special system prompt needed
- Two drain checkpoints: after tool execution and after final LLM
response ("last-mile" to prevent dropping late arrivals)
- Bounded by MAX_INJECTION_CYCLES (5) to prevent consuming the
iteration budget on rapid follow-ups
- had_injections flag bypasses _sent_in_turn suppression so follow-up
responses are always delivered
Closes #1609
* fix(agent): harden mid-turn injection with streaming fix, bounded queue, and message safety
- Fix streaming protocol violation: Checkpoint 2 now checks for injections
BEFORE calling on_stream_end, passing resuming=True when injections found
so streaming channels (Feishu) don't prematurely finalize the card
- Bound pending queue to maxsize=20 with QueueFull handling
- Add warning log when injection batch exceeds _MAX_INJECTIONS_PER_TURN
- Re-publish leftover queue messages to bus in _dispatch finally block to
prevent silent message loss on early exit (max_iterations, tool_error, cancel)
- Fix PEP 8 blank line before dataclass and logger.info indentation
- Add 12 new tests covering drain, checkpoints, cycle cap, queue routing,
cleanup, and leftover re-publish
This commit is contained in:
parent
df6f9dd171
commit
bc4cc49a59
@ -237,6 +237,10 @@ class AgentLoop:
|
||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||
self._background_tasks: list[asyncio.Task] = []
|
||||
self._session_locks: dict[str, asyncio.Lock] = {}
|
||||
# Per-session pending queues for mid-turn message injection.
|
||||
# When a session has an active task, new messages for that session
|
||||
# are routed here instead of creating a new task.
|
||||
self._pending_queues: dict[str, asyncio.Queue] = {}
|
||||
# NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3.
|
||||
_max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3"))
|
||||
self._concurrency_gate: asyncio.Semaphore | None = (
|
||||
@ -348,13 +352,16 @@ class AgentLoop:
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict]]:
|
||||
pending_queue: asyncio.Queue | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict], str, bool]:
|
||||
"""Run the agent iteration loop.
|
||||
|
||||
*on_stream*: called with each content delta during streaming.
|
||||
*on_stream_end(resuming)*: called when a streaming session finishes.
|
||||
``resuming=True`` means tool calls follow (spinner should restart);
|
||||
``resuming=False`` means this is the final response.
|
||||
|
||||
Returns (final_content, tools_used, messages, stop_reason, had_injections).
|
||||
"""
|
||||
loop_hook = _LoopHook(
|
||||
self,
|
||||
@ -376,6 +383,18 @@ class AgentLoop:
|
||||
return
|
||||
self._set_runtime_checkpoint(session, payload)
|
||||
|
||||
async def _drain_pending() -> list[InboundMessage]:
|
||||
"""Non-blocking drain of follow-up messages from the pending queue."""
|
||||
if pending_queue is None:
|
||||
return []
|
||||
items: list[InboundMessage] = []
|
||||
while True:
|
||||
try:
|
||||
items.append(pending_queue.get_nowait())
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
return items
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
@ -392,13 +411,14 @@ class AgentLoop:
|
||||
provider_retry_mode=self.provider_retry_mode,
|
||||
progress_callback=on_progress,
|
||||
checkpoint_callback=_checkpoint,
|
||||
injection_callback=_drain_pending,
|
||||
))
|
||||
self._last_usage = result.usage
|
||||
if result.stop_reason == "max_iterations":
|
||||
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||
elif result.stop_reason == "error":
|
||||
logger.error("LLM returned error: {}", (result.final_content or "")[:200])
|
||||
return result.final_content, result.tools_used, result.messages
|
||||
return result.final_content, result.tools_used, result.messages, result.stop_reason, result.had_injections
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||
@ -429,67 +449,112 @@ class AgentLoop:
|
||||
if result:
|
||||
await self.bus.publish_outbound(result)
|
||||
continue
|
||||
# If this session already has an active pending queue (i.e. a task
|
||||
# is processing this session), route the message there for mid-turn
|
||||
# injection instead of creating a competing task.
|
||||
if msg.session_key in self._pending_queues:
|
||||
try:
|
||||
self._pending_queues[msg.session_key].put_nowait(msg)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
"Pending queue full for session {}, dropping follow-up",
|
||||
msg.session_key,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Routed follow-up message to pending queue for session {}",
|
||||
msg.session_key,
|
||||
)
|
||||
continue
|
||||
task = asyncio.create_task(self._dispatch(msg))
|
||||
self._active_tasks.setdefault(msg.session_key, []).append(task)
|
||||
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
||||
|
||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||
"""Process a message: per-session serial, cross-session concurrent."""
|
||||
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
|
||||
session_key = msg.session_key
|
||||
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
|
||||
gate = self._concurrency_gate or nullcontext()
|
||||
async with lock, gate:
|
||||
try:
|
||||
on_stream = on_stream_end = None
|
||||
if msg.metadata.get("_wants_stream"):
|
||||
# Split one answer into distinct stream segments.
|
||||
stream_base_id = f"{msg.session_key}:{time.time_ns()}"
|
||||
stream_segment = 0
|
||||
|
||||
def _current_stream_id() -> str:
|
||||
return f"{stream_base_id}:{stream_segment}"
|
||||
# Register a pending queue so follow-up messages for this session are
|
||||
# routed here (mid-turn injection) instead of spawning a new task.
|
||||
pending = asyncio.Queue(maxsize=20)
|
||||
self._pending_queues[session_key] = pending
|
||||
|
||||
async def on_stream(delta: str) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_delta"] = True
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
try:
|
||||
async with lock, gate:
|
||||
try:
|
||||
on_stream = on_stream_end = None
|
||||
if msg.metadata.get("_wants_stream"):
|
||||
# Split one answer into distinct stream segments.
|
||||
stream_base_id = f"{msg.session_key}:{time.time_ns()}"
|
||||
stream_segment = 0
|
||||
|
||||
def _current_stream_id() -> str:
|
||||
return f"{stream_base_id}:{stream_segment}"
|
||||
|
||||
async def on_stream(delta: str) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_delta"] = True
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content=delta,
|
||||
metadata=meta,
|
||||
))
|
||||
|
||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||
nonlocal stream_segment
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_end"] = True
|
||||
meta["_resuming"] = resuming
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="",
|
||||
metadata=meta,
|
||||
))
|
||||
stream_segment += 1
|
||||
|
||||
response = await self._process_message(
|
||||
msg, on_stream=on_stream, on_stream_end=on_stream_end,
|
||||
pending_queue=pending,
|
||||
)
|
||||
if response is not None:
|
||||
await self.bus.publish_outbound(response)
|
||||
elif msg.channel == "cli":
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content=delta,
|
||||
metadata=meta,
|
||||
content="", metadata=msg.metadata or {},
|
||||
))
|
||||
|
||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||
nonlocal stream_segment
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_end"] = True
|
||||
meta["_resuming"] = resuming
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="",
|
||||
metadata=meta,
|
||||
))
|
||||
stream_segment += 1
|
||||
|
||||
response = await self._process_message(
|
||||
msg, on_stream=on_stream, on_stream_end=on_stream_end,
|
||||
)
|
||||
if response is not None:
|
||||
await self.bus.publish_outbound(response)
|
||||
elif msg.channel == "cli":
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Task cancelled for session {}", session_key)
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Error processing message for session {}", session_key)
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="", metadata=msg.metadata or {},
|
||||
content="Sorry, I encountered an error.",
|
||||
))
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Task cancelled for session {}", msg.session_key)
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Error processing message for session {}", msg.session_key)
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Sorry, I encountered an error.",
|
||||
))
|
||||
finally:
|
||||
# Drain any messages still in the pending queue and re-publish
|
||||
# them to the bus so they are processed as fresh inbound messages
|
||||
# rather than silently lost.
|
||||
queue = self._pending_queues.pop(session_key, None)
|
||||
if queue is not None:
|
||||
leftover = 0
|
||||
while True:
|
||||
try:
|
||||
item = queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
await self.bus.publish_inbound(item)
|
||||
leftover += 1
|
||||
if leftover:
|
||||
logger.info(
|
||||
"Re-published {} leftover message(s) to bus for session {}",
|
||||
leftover, session_key,
|
||||
)
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
"""Drain pending background archives, then close MCP connections."""
|
||||
@ -521,6 +586,7 @@ class AgentLoop:
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
pending_queue: asyncio.Queue | None = None,
|
||||
) -> OutboundMessage | None:
|
||||
"""Process a single inbound message and return the response."""
|
||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||
@ -546,7 +612,7 @@ class AgentLoop:
|
||||
session_summary=pending,
|
||||
current_role=current_role,
|
||||
)
|
||||
final_content, _, all_msgs = await self._run_agent_loop(
|
||||
final_content, _, all_msgs, _, _ = await self._run_agent_loop(
|
||||
messages, session=session, channel=channel, chat_id=chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
)
|
||||
@ -598,7 +664,7 @@ class AgentLoop:
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
|
||||
))
|
||||
|
||||
final_content, _, all_msgs = await self._run_agent_loop(
|
||||
final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
|
||||
initial_messages,
|
||||
on_progress=on_progress or _bus_progress,
|
||||
on_stream=on_stream,
|
||||
@ -606,6 +672,7 @@ class AgentLoop:
|
||||
session=session,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
pending_queue=pending_queue,
|
||||
)
|
||||
|
||||
if final_content is None or not final_content.strip():
|
||||
@ -616,8 +683,13 @@ class AgentLoop:
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||
|
||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||
return None
|
||||
# When follow-up messages were injected mid-turn, the LLM's final
|
||||
# response addresses those follow-ups. Always send the response in
|
||||
# this case, even if MessageTool was used earlier in the turn — the
|
||||
# follow-up response is new content the user hasn't seen.
|
||||
if not had_injections:
|
||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||
return None
|
||||
|
||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
@ -31,7 +31,11 @@ from nanobot.utils.runtime import (
|
||||
|
||||
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
||||
_MAX_EMPTY_RETRIES = 2
|
||||
_MAX_INJECTIONS_PER_TURN = 3
|
||||
_MAX_INJECTION_CYCLES = 5
|
||||
_SNIP_SAFETY_BUFFER = 1024
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AgentRunSpec:
|
||||
"""Configuration for a single agent execution."""
|
||||
@ -56,6 +60,7 @@ class AgentRunSpec:
|
||||
provider_retry_mode: str = "standard"
|
||||
progress_callback: Any | None = None
|
||||
checkpoint_callback: Any | None = None
|
||||
injection_callback: Any | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -69,6 +74,7 @@ class AgentRunResult:
|
||||
stop_reason: str = "completed"
|
||||
error: str | None = None
|
||||
tool_events: list[dict[str, str]] = field(default_factory=list)
|
||||
had_injections: bool = False
|
||||
|
||||
|
||||
class AgentRunner:
|
||||
@ -77,6 +83,38 @@ class AgentRunner:
|
||||
def __init__(self, provider: LLMProvider):
|
||||
self.provider = provider
|
||||
|
||||
async def _drain_injections(self, spec: AgentRunSpec) -> list[str]:
|
||||
"""Drain pending user messages via the injection callback.
|
||||
|
||||
Returns all drained message contents (capped by
|
||||
``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is
|
||||
nothing to inject. Messages beyond the cap are logged so they
|
||||
are not silently lost.
|
||||
"""
|
||||
if spec.injection_callback is None:
|
||||
return []
|
||||
try:
|
||||
items = await spec.injection_callback()
|
||||
except Exception:
|
||||
logger.exception("injection_callback failed")
|
||||
return []
|
||||
if not items:
|
||||
return []
|
||||
# items are InboundMessage objects from _drain_pending
|
||||
texts: list[str] = []
|
||||
for item in items:
|
||||
text = getattr(item, "content", str(item))
|
||||
if text.strip():
|
||||
texts.append(text)
|
||||
if len(texts) > _MAX_INJECTIONS_PER_TURN:
|
||||
dropped = len(texts) - _MAX_INJECTIONS_PER_TURN
|
||||
logger.warning(
|
||||
"Injection batch has {} messages, capping to {} ({} dropped)",
|
||||
len(texts), _MAX_INJECTIONS_PER_TURN, dropped,
|
||||
)
|
||||
texts = texts[-_MAX_INJECTIONS_PER_TURN:]
|
||||
return texts
|
||||
|
||||
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
||||
hook = spec.hook or AgentHook()
|
||||
messages = list(spec.initial_messages)
|
||||
@ -88,6 +126,8 @@ class AgentRunner:
|
||||
tool_events: list[dict[str, str]] = []
|
||||
external_lookup_counts: dict[str, int] = {}
|
||||
empty_content_retries = 0
|
||||
had_injections = False
|
||||
injection_cycles = 0
|
||||
|
||||
for iteration in range(spec.max_iterations):
|
||||
try:
|
||||
@ -181,6 +221,18 @@ class AgentRunner:
|
||||
},
|
||||
)
|
||||
empty_content_retries = 0
|
||||
# Checkpoint 1: drain injections after tools, before next LLM call
|
||||
if injection_cycles < _MAX_INJECTION_CYCLES:
|
||||
injections = await self._drain_injections(spec)
|
||||
if injections:
|
||||
had_injections = True
|
||||
injection_cycles += 1
|
||||
for text in injections:
|
||||
messages.append({"role": "user", "content": text})
|
||||
logger.info(
|
||||
"Injected {} follow-up message(s) after tool execution ({}/{})",
|
||||
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
||||
)
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
@ -216,8 +268,29 @@ class AgentRunner:
|
||||
context.tool_calls = list(response.tool_calls)
|
||||
clean = hook.finalize_content(context, response.content)
|
||||
|
||||
# Check for mid-turn injections BEFORE signaling stream end.
|
||||
# If injections are found we keep the stream alive (resuming=True)
|
||||
# so streaming channels don't prematurely finalize the card.
|
||||
_injected_after_final = False
|
||||
if injection_cycles < _MAX_INJECTION_CYCLES:
|
||||
injections = await self._drain_injections(spec)
|
||||
if injections:
|
||||
had_injections = True
|
||||
injection_cycles += 1
|
||||
_injected_after_final = True
|
||||
for text in injections:
|
||||
messages.append({"role": "user", "content": text})
|
||||
logger.info(
|
||||
"Injected {} follow-up message(s) after final response ({}/{})",
|
||||
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
||||
)
|
||||
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=False)
|
||||
await hook.on_stream_end(context, resuming=_injected_after_final)
|
||||
|
||||
if _injected_after_final:
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
if response.finish_reason == "error":
|
||||
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
||||
@ -283,6 +356,7 @@ class AgentRunner:
|
||||
stop_reason=stop_reason,
|
||||
error=error,
|
||||
tool_events=tool_events,
|
||||
had_injections=had_injections,
|
||||
)
|
||||
|
||||
def _build_request_kwargs(
|
||||
|
||||
@ -278,7 +278,7 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path):
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
content, tools_used, messages = await loop._run_agent_loop(
|
||||
content, tools_used, messages, _, _ = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hi"}]
|
||||
)
|
||||
|
||||
@ -302,7 +302,7 @@ async def test_agent_loop_extra_hook_error_isolation(tmp_path):
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
content, _, _ = await loop._run_agent_loop(
|
||||
content, _, _, _, _ = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hi"}]
|
||||
)
|
||||
|
||||
@ -344,7 +344,7 @@ async def test_agent_loop_no_hooks_backward_compat(tmp_path):
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
loop.max_iterations = 2
|
||||
|
||||
content, tools_used, _ = await loop._run_agent_loop([])
|
||||
content, tools_used, _, _, _ = await loop._run_agent_loop([])
|
||||
assert content == (
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
|
||||
@ -798,7 +798,7 @@ async def test_loop_max_iterations_message_stays_stable(tmp_path):
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
loop.max_iterations = 2
|
||||
|
||||
final_content, _, _ = await loop._run_agent_loop([])
|
||||
final_content, _, _, _, _ = await loop._run_agent_loop([])
|
||||
|
||||
assert final_content == (
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
@ -825,7 +825,7 @@ async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp
|
||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||
endings.append(resuming)
|
||||
|
||||
final_content, _, _ = await loop._run_agent_loop(
|
||||
final_content, _, _, _, _ = await loop._run_agent_loop(
|
||||
[],
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
@ -849,7 +849,7 @@ async def test_loop_retries_think_only_final_response(tmp_path):
|
||||
|
||||
loop.provider.chat_with_retry = chat_with_retry
|
||||
|
||||
final_content, _, _ = await loop._run_agent_loop([])
|
||||
final_content, _, _, _, _ = await loop._run_agent_loop([])
|
||||
|
||||
assert final_content == "Recovered answer"
|
||||
assert call_count["n"] == 2
|
||||
@ -999,3 +999,412 @@ async def test_runner_passes_cached_tokens_to_hook_context():
|
||||
|
||||
assert len(captured_usage) == 1
|
||||
assert captured_usage[0]["cached_tokens"] == 150
|
||||
|
||||
|
||||
# ── Mid-turn injection tests ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_injections_returns_empty_when_no_callback():
|
||||
"""No injection_callback → empty list."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
runner = AgentRunner(provider)
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
spec = AgentRunSpec(
|
||||
initial_messages=[], tools=tools, model="m",
|
||||
max_iterations=1, max_tool_result_chars=1000,
|
||||
injection_callback=None,
|
||||
)
|
||||
result = await runner._drain_injections(spec)
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_injections_extracts_content_from_inbound_messages():
|
||||
"""Should extract .content from InboundMessage objects."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
provider = MagicMock()
|
||||
runner = AgentRunner(provider)
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
msgs = [
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello"),
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="world"),
|
||||
]
|
||||
|
||||
async def cb():
|
||||
return msgs
|
||||
|
||||
spec = AgentRunSpec(
|
||||
initial_messages=[], tools=tools, model="m",
|
||||
max_iterations=1, max_tool_result_chars=1000,
|
||||
injection_callback=cb,
|
||||
)
|
||||
result = await runner._drain_injections(spec)
|
||||
assert result == ["hello", "world"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_injections_caps_at_max_and_logs_warning():
|
||||
"""When more than _MAX_INJECTIONS_PER_TURN items, only the last N are kept."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
provider = MagicMock()
|
||||
runner = AgentRunner(provider)
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
msgs = [
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}")
|
||||
for i in range(_MAX_INJECTIONS_PER_TURN + 3)
|
||||
]
|
||||
|
||||
async def cb():
|
||||
return msgs
|
||||
|
||||
spec = AgentRunSpec(
|
||||
initial_messages=[], tools=tools, model="m",
|
||||
max_iterations=1, max_tool_result_chars=1000,
|
||||
injection_callback=cb,
|
||||
)
|
||||
result = await runner._drain_injections(spec)
|
||||
assert len(result) == _MAX_INJECTIONS_PER_TURN
|
||||
# Should keep the LAST _MAX_INJECTIONS_PER_TURN items
|
||||
assert result[0] == "msg3"
|
||||
assert result[-1] == f"msg{_MAX_INJECTIONS_PER_TURN + 2}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_injections_skips_empty_content():
|
||||
"""Messages with blank content should be filtered out."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
provider = MagicMock()
|
||||
runner = AgentRunner(provider)
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
msgs = [
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content=""),
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content=" "),
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="valid"),
|
||||
]
|
||||
|
||||
async def cb():
|
||||
return msgs
|
||||
|
||||
spec = AgentRunSpec(
|
||||
initial_messages=[], tools=tools, model="m",
|
||||
max_iterations=1, max_tool_result_chars=1000,
|
||||
injection_callback=cb,
|
||||
)
|
||||
result = await runner._drain_injections(spec)
|
||||
assert result == ["valid"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_injections_handles_callback_exception():
|
||||
"""If the callback raises, return empty list (error is logged)."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
runner = AgentRunner(provider)
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
async def cb():
|
||||
raise RuntimeError("boom")
|
||||
|
||||
spec = AgentRunSpec(
|
||||
initial_messages=[], tools=tools, model="m",
|
||||
max_iterations=1, max_tool_result_chars=1000,
|
||||
injection_callback=cb,
|
||||
)
|
||||
result = await runner._drain_injections(spec)
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkpoint1_injects_after_tool_execution():
|
||||
"""Follow-up messages are injected after tool execution, before next LLM call."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
provider = MagicMock()
|
||||
call_count = {"n": 0}
|
||||
captured_messages = []
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
captured_messages.append(list(messages))
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(
|
||||
content="using tool",
|
||||
tool_calls=[ToolCallRequest(id="c1", name="read_file", arguments={"path": "x"})],
|
||||
usage={},
|
||||
)
|
||||
return LLMResponse(content="final answer", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(return_value="file content")
|
||||
|
||||
injection_queue = asyncio.Queue()
|
||||
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not injection_queue.empty():
|
||||
items.append(await injection_queue.get())
|
||||
return items
|
||||
|
||||
# Put a follow-up message in the queue before the run starts
|
||||
await injection_queue.put(
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question")
|
||||
)
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "hello"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=5,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
injection_callback=inject_cb,
|
||||
))
|
||||
|
||||
assert result.had_injections is True
|
||||
assert result.final_content == "final answer"
|
||||
# The second call should have the injected user message
|
||||
assert call_count["n"] == 2
|
||||
last_messages = captured_messages[-1]
|
||||
injected = [m for m in last_messages if m.get("role") == "user" and m.get("content") == "follow-up question"]
|
||||
assert len(injected) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkpoint2_injects_after_final_response_with_resuming_stream():
|
||||
"""After final response, if injections exist, stream_end should get resuming=True."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
provider = MagicMock()
|
||||
call_count = {"n": 0}
|
||||
stream_end_calls = []
|
||||
|
||||
class TrackingHook(AgentHook):
|
||||
def wants_streaming(self) -> bool:
|
||||
return True
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
stream_end_calls.append(resuming)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return content
|
||||
|
||||
async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(content="first answer", tool_calls=[], usage={})
|
||||
return LLMResponse(content="second answer", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_stream_with_retry = chat_stream_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
injection_queue = asyncio.Queue()
|
||||
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not injection_queue.empty():
|
||||
items.append(await injection_queue.get())
|
||||
return items
|
||||
|
||||
# Inject a follow-up that arrives during the first response
|
||||
await injection_queue.put(
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="quick follow-up")
|
||||
)
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "hello"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=5,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
hook=TrackingHook(),
|
||||
injection_callback=inject_cb,
|
||||
))
|
||||
|
||||
assert result.had_injections is True
|
||||
assert result.final_content == "second answer"
|
||||
assert call_count["n"] == 2
|
||||
# First stream_end should have resuming=True (because injections found)
|
||||
assert stream_end_calls[0] is True
|
||||
# Second (final) stream_end should have resuming=False
|
||||
assert stream_end_calls[-1] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_injection_cycles_capped_at_max():
|
||||
"""Injection cycles should be capped at _MAX_INJECTION_CYCLES."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
provider = MagicMock()
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
drain_count = {"n": 0}
|
||||
|
||||
async def inject_cb():
|
||||
drain_count["n"] += 1
|
||||
# Only inject for the first _MAX_INJECTION_CYCLES drains
|
||||
if drain_count["n"] <= _MAX_INJECTION_CYCLES:
|
||||
return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drain_count['n']}")]
|
||||
return []
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "start"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=20,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
injection_callback=inject_cb,
|
||||
))
|
||||
|
||||
assert result.had_injections is True
|
||||
# Should be capped: _MAX_INJECTION_CYCLES injection rounds + 1 final round
|
||||
assert call_count["n"] == _MAX_INJECTION_CYCLES + 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_injections_flag_is_false_by_default():
    """Without an injection callback, had_injections must come back False."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    async def chat_with_retry(**kwargs):
        # Single plain answer; the turn ends on the first iteration.
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.had_injections is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pending_queue_cleanup_on_dispatch(tmp_path):
    """_dispatch must remove its session's pending queue once it finishes."""
    from nanobot.bus.events import InboundMessage

    loop = _make_loop(tmp_path)

    async def chat_with_retry(**kwargs):
        return LLMResponse(content="done", tool_calls=[], usage={})

    loop.provider.chat_with_retry = chat_with_retry

    msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello")

    # No queue registered before the dispatch starts...
    assert msg.session_key not in loop._pending_queues

    await loop._dispatch(msg)

    # ...and none left behind after it completes.
    assert msg.session_key not in loop._pending_queues
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_followup_routed_to_pending_queue(tmp_path):
    """A follow-up for a session with an active dispatch lands in its pending queue."""
    from nanobot.bus.events import InboundMessage

    loop = _make_loop(tmp_path)

    # Register a queue the way run() does while a dispatch is in flight.
    pending = asyncio.Queue(maxsize=20)
    loop._pending_queues["cli:c"] = pending

    followup = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up")

    # Mirror run()'s routing decision: a registered session_key means the
    # message is enqueued rather than starting a fresh dispatch.
    assert followup.session_key in loop._pending_queues
    loop._pending_queues[followup.session_key].put_nowait(followup)

    assert not pending.empty()
    assert pending.get_nowait().content == "follow-up"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_republishes_leftover_queue_messages(tmp_path):
    """Messages left in the pending queue after _dispatch are re-published to the bus.

    Exercises the finally-block cleanup that prevents message loss when the
    runner exits early (e.g., max_iterations, tool_error) with messages
    still queued.
    """
    from nanobot.bus.events import InboundMessage

    loop = _make_loop(tmp_path)
    bus = loop.bus

    # Register a queue with leftover messages, as if a dispatch just ended
    # without draining it, then run the cleanup logic directly.
    session_key = "cli:c"
    pending = asyncio.Queue(maxsize=20)
    loop._pending_queues[session_key] = pending
    for body in ("leftover-1", "leftover-2"):
        pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content=body))

    # Same drain-and-republish steps the finally block performs.
    queue = loop._pending_queues.pop(session_key, None)
    assert queue is not None
    leftover = 0
    while not queue.empty():
        await bus.publish_inbound(queue.get_nowait())
        leftover += 1

    assert leftover == 2

    # Both messages must now be retrievable from the bus.
    received = []
    while not bus.inbound.empty():
        received.append(await asyncio.wait_for(bus.consume_inbound(), timeout=0.5))
    contents = [m.content for m in received]
    assert "leftover-1" in contents
    assert "leftover-2" in contents
|
||||
|
||||
@ -107,7 +107,7 @@ class TestMessageToolSuppressLogic:
|
||||
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
progress.append((content, tool_hint))
|
||||
|
||||
final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
||||
final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
||||
|
||||
assert final_content == "Done"
|
||||
assert progress == [
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user