mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-14 15:09:55 +00:00
feat(agent): enhance session key handling for follow-up messages
This commit is contained in:
parent
36d2a11e73
commit
f6c39ec946
@ -324,6 +324,12 @@ class AgentLoop:
|
||||
|
||||
return format_tool_hints(tool_calls)
|
||||
|
||||
def _effective_session_key(self, msg: InboundMessage) -> str:
    """Resolve the session key used for task routing and mid-turn injections.

    The message's own ``session_key`` is returned whenever unified-session
    mode is off or the message already carries a ``session_key_override``.
    Only a message without an override, received while unified-session mode
    is on, is mapped to ``UNIFIED_SESSION_KEY``.
    """
    # Guard clause: anything that is not "unified mode + no override"
    # keeps the key the message reports for itself.
    if not self._unified_session or msg.session_key_override:
        return msg.session_key
    return UNIFIED_SESSION_KEY
|
||||
|
||||
async def _run_agent_loop(
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
@ -430,30 +436,32 @@ class AgentLoop:
|
||||
if result:
|
||||
await self.bus.publish_outbound(result)
|
||||
continue
|
||||
effective_key = self._effective_session_key(msg)
|
||||
# If this session already has an active pending queue (i.e. a task
|
||||
# is processing this session), route the message there for mid-turn
|
||||
# injection instead of creating a competing task.
|
||||
if msg.session_key in self._pending_queues:
|
||||
if effective_key in self._pending_queues:
|
||||
pending_msg = msg
|
||||
if effective_key != msg.session_key:
|
||||
pending_msg = dataclasses.replace(
|
||||
msg,
|
||||
session_key_override=effective_key,
|
||||
)
|
||||
try:
|
||||
self._pending_queues[msg.session_key].put_nowait(msg)
|
||||
self._pending_queues[effective_key].put_nowait(pending_msg)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
"Pending queue full for session {}, dropping follow-up",
|
||||
msg.session_key,
|
||||
effective_key,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Routed follow-up message to pending queue for session {}",
|
||||
msg.session_key,
|
||||
effective_key,
|
||||
)
|
||||
continue
|
||||
# Compute the effective session key before dispatching
|
||||
# This ensures /stop command can find tasks correctly when unified session is enabled
|
||||
effective_key = (
|
||||
UNIFIED_SESSION_KEY
|
||||
if self._unified_session and not msg.session_key_override
|
||||
else msg.session_key
|
||||
)
|
||||
task = asyncio.create_task(self._dispatch(msg))
|
||||
self._active_tasks.setdefault(effective_key, []).append(task)
|
||||
task.add_done_callback(
|
||||
@ -465,9 +473,9 @@ class AgentLoop:
|
||||
|
||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||
"""Process a message: per-session serial, cross-session concurrent."""
|
||||
if self._unified_session and not msg.session_key_override:
|
||||
msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY)
|
||||
session_key = msg.session_key
|
||||
session_key = self._effective_session_key(msg)
|
||||
if session_key != msg.session_key:
|
||||
msg = dataclasses.replace(msg, session_key_override=session_key)
|
||||
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
|
||||
gate = self._concurrency_gate or nullcontext()
|
||||
|
||||
|
||||
@ -309,6 +309,14 @@ class AgentRunner:
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
assistant_message: dict[str, Any] | None = None
|
||||
if response.finish_reason != "error" and not is_blank_text(clean):
|
||||
assistant_message = build_assistant_message(
|
||||
clean,
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
)
|
||||
|
||||
# Check for mid-turn injections BEFORE signaling stream end.
|
||||
# If injections are found we keep the stream alive (resuming=True)
|
||||
# so streaming channels don't prematurely finalize the card.
|
||||
@ -319,6 +327,19 @@ class AgentRunner:
|
||||
had_injections = True
|
||||
injection_cycles += 1
|
||||
_injected_after_final = True
|
||||
if assistant_message is not None:
|
||||
messages.append(assistant_message)
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
"phase": "final_response",
|
||||
"iteration": iteration,
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [],
|
||||
},
|
||||
)
|
||||
for text in injections:
|
||||
messages.append({"role": "user", "content": text})
|
||||
logger.info(
|
||||
@ -354,7 +375,7 @@ class AgentRunner:
|
||||
await hook.after_iteration(context)
|
||||
break
|
||||
|
||||
messages.append(build_assistant_message(
|
||||
messages.append(assistant_message or build_assistant_message(
|
||||
clean,
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
|
||||
@ -1862,6 +1862,66 @@ async def test_checkpoint2_injects_after_final_response_with_resuming_stream():
|
||||
assert stream_end_calls[-1] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_checkpoint2_preserves_final_response_in_history_before_followup():
    """A follow-up injected after a final answer must still see that answer in history.

    The fake provider answers "first answer" on its first call and
    "second answer" afterwards. One follow-up message is queued before the
    run starts, so the injection callback hands it over once and returns an
    empty list after that. The test then checks that the second provider call
    received the first answer ahead of the follow-up, and that both assistant
    answers are present in the final message list.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    provider = MagicMock()
    llm_calls = 0
    seen_histories = []

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal llm_calls
        llm_calls += 1
        # Snapshot shallow copies of each entry so that what this call
        # observed is recorded independently of the live message list.
        seen_histories.append([dict(entry) for entry in messages])
        reply = "first answer" if llm_calls == 1 else "second answer"
        return LLMResponse(content=reply, tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    followups = asyncio.Queue()

    async def inject_cb():
        # Drain everything currently queued; an empty list means "nothing to inject".
        drained = []
        while not followups.empty():
            drained.append(await followups.get())
        return drained

    followup = InboundMessage(
        channel="cli",
        sender_id="u",
        chat_id="c",
        content="follow-up question",
    )
    await followups.put(followup)

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await runner.run(spec)

    assert result.final_content == "second answer"
    assert llm_calls == 2

    # History handed to the provider on its last call: the first answer must
    # sit between the original prompt and the injected follow-up.
    expected_last_history = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "first answer"},
        {"role": "user", "content": "follow-up question"},
    ]
    assert seen_histories[-1] == expected_last_history

    # Both assistant turns must be present in the returned messages, in order.
    assistant_turns = [
        {"role": entry["role"], "content": entry["content"]}
        for entry in result.messages
        if entry.get("role") == "assistant"
    ]
    expected_assistant_turns = [
        {"role": "assistant", "content": "first answer"},
        {"role": "assistant", "content": "second answer"},
    ]
    assert assistant_turns == expected_assistant_turns
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_injection_cycles_capped_at_max():
|
||||
"""Injection cycles should be capped at _MAX_INJECTION_CYCLES."""
|
||||
@ -1953,25 +2013,33 @@ async def test_pending_queue_cleanup_on_dispatch(tmp_path):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_followup_routed_to_pending_queue(tmp_path):
|
||||
"""When a session has an active dispatch, follow-up messages go to pending queue."""
|
||||
"""Unified-session follow-ups should route into the active pending queue."""
|
||||
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop = _make_loop(tmp_path)
|
||||
loop._unified_session = True
|
||||
loop._dispatch = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
# Simulate an active dispatch by manually adding a pending queue
|
||||
pending = asyncio.Queue(maxsize=20)
|
||||
loop._pending_queues["cli:c"] = pending
|
||||
loop._pending_queues[UNIFIED_SESSION_KEY] = pending
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up")
|
||||
run_task = asyncio.create_task(loop.run())
|
||||
msg = InboundMessage(channel="discord", sender_id="u", chat_id="c", content="follow-up")
|
||||
await loop.bus.publish_inbound(msg)
|
||||
|
||||
# Directly test the routing logic from run() — if session_key is in
|
||||
# _pending_queues, the message should be put into the queue.
|
||||
assert msg.session_key in loop._pending_queues
|
||||
loop._pending_queues[msg.session_key].put_nowait(msg)
|
||||
deadline = time.time() + 2
|
||||
while pending.empty() and time.time() < deadline:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
loop.stop()
|
||||
await asyncio.wait_for(run_task, timeout=2)
|
||||
|
||||
assert loop._dispatch.await_count == 0
|
||||
assert not pending.empty()
|
||||
queued_msg = pending.get_nowait()
|
||||
assert queued_msg.content == "follow-up"
|
||||
assert queued_msg.session_key == UNIFIED_SESSION_KEY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user