feat(agent): enhance session key handling for follow-up messages

This commit is contained in:
Xubin Ren 2026-04-11 13:17:28 +00:00 committed by Xubin Ren
parent 36d2a11e73
commit f6c39ec946
3 changed files with 118 additions and 21 deletions

View File

@ -324,6 +324,12 @@ class AgentLoop:
return format_tool_hints(tool_calls)
def _effective_session_key(self, msg: InboundMessage) -> str:
    """Resolve the session key used for task routing and mid-turn injections.

    In unified-session mode, every message that carries no explicit
    ``session_key_override`` is funneled into the single shared session;
    otherwise the message keeps its own session key unchanged.
    """
    collapse_to_unified = self._unified_session and not msg.session_key_override
    return UNIFIED_SESSION_KEY if collapse_to_unified else msg.session_key
async def _run_agent_loop(
self,
initial_messages: list[dict],
@ -430,30 +436,32 @@ class AgentLoop:
if result:
await self.bus.publish_outbound(result)
continue
effective_key = self._effective_session_key(msg)
# If this session already has an active pending queue (i.e. a task
# is processing this session), route the message there for mid-turn
# injection instead of creating a competing task.
if msg.session_key in self._pending_queues:
if effective_key in self._pending_queues:
pending_msg = msg
if effective_key != msg.session_key:
pending_msg = dataclasses.replace(
msg,
session_key_override=effective_key,
)
try:
self._pending_queues[msg.session_key].put_nowait(msg)
self._pending_queues[effective_key].put_nowait(pending_msg)
except asyncio.QueueFull:
logger.warning(
"Pending queue full for session {}, dropping follow-up",
msg.session_key,
effective_key,
)
else:
logger.info(
"Routed follow-up message to pending queue for session {}",
msg.session_key,
effective_key,
)
continue
# Compute the effective session key before dispatching
# This ensures /stop command can find tasks correctly when unified session is enabled
effective_key = (
UNIFIED_SESSION_KEY
if self._unified_session and not msg.session_key_override
else msg.session_key
)
task = asyncio.create_task(self._dispatch(msg))
self._active_tasks.setdefault(effective_key, []).append(task)
task.add_done_callback(
@ -465,9 +473,9 @@ class AgentLoop:
async def _dispatch(self, msg: InboundMessage) -> None:
"""Process a message: per-session serial, cross-session concurrent."""
if self._unified_session and not msg.session_key_override:
msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY)
session_key = msg.session_key
session_key = self._effective_session_key(msg)
if session_key != msg.session_key:
msg = dataclasses.replace(msg, session_key_override=session_key)
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
gate = self._concurrency_gate or nullcontext()

View File

@ -309,6 +309,14 @@ class AgentRunner:
await hook.after_iteration(context)
continue
assistant_message: dict[str, Any] | None = None
if response.finish_reason != "error" and not is_blank_text(clean):
assistant_message = build_assistant_message(
clean,
reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
)
# Check for mid-turn injections BEFORE signaling stream end.
# If injections are found we keep the stream alive (resuming=True)
# so streaming channels don't prematurely finalize the card.
@ -319,6 +327,19 @@ class AgentRunner:
had_injections = True
injection_cycles += 1
_injected_after_final = True
if assistant_message is not None:
messages.append(assistant_message)
await self._emit_checkpoint(
spec,
{
"phase": "final_response",
"iteration": iteration,
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": [],
"pending_tool_calls": [],
},
)
for text in injections:
messages.append({"role": "user", "content": text})
logger.info(
@ -354,7 +375,7 @@ class AgentRunner:
await hook.after_iteration(context)
break
messages.append(build_assistant_message(
messages.append(assistant_message or build_assistant_message(
clean,
reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,

View File

@ -1862,6 +1862,66 @@ async def test_checkpoint2_injects_after_final_response_with_resuming_stream():
assert stream_end_calls[-1] is False
@pytest.mark.asyncio
async def test_checkpoint2_preserves_final_response_in_history_before_followup():
    """A follow-up injected after a final answer must still see that answer in history."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    # Provider stub: call 1 answers the original prompt, call 2 answers the
    # injected follow-up. Each call's message list is captured so we can
    # inspect the exact history the model saw.
    provider = MagicMock()
    call_count = {"n": 0}
    captured_messages = []

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        # Shallow-copy each message dict so later in-place mutation by the
        # runner cannot retroactively change the captured snapshot.
        captured_messages.append([dict(message) for message in messages])
        if call_count["n"] == 1:
            return LLMResponse(content="first answer", tool_calls=[], usage={})
        return LLMResponse(content="second answer", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    # One pending mid-turn injection; the callback drains whatever is queued
    # at the moment the runner polls for injections.
    injection_queue = asyncio.Queue()

    async def inject_cb():
        items = []
        while not injection_queue.empty():
            items.append(await injection_queue.get())
        return items

    await injection_queue.put(
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question")
    )

    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    ))

    assert result.final_content == "second answer"
    assert call_count["n"] == 2
    # The second LLM call must see the first final answer in history,
    # followed by the injected follow-up — this is the bug being pinned.
    assert captured_messages[-1] == [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "first answer"},
        {"role": "user", "content": "follow-up question"},
    ]
    # Both assistant turns must survive in the final transcript, in order.
    assert [
        {"role": message["role"], "content": message["content"]}
        for message in result.messages
        if message.get("role") == "assistant"
    ] == [
        {"role": "assistant", "content": "first answer"},
        {"role": "assistant", "content": "second answer"},
    ]
@pytest.mark.asyncio
async def test_injection_cycles_capped_at_max():
"""Injection cycles should be capped at _MAX_INJECTION_CYCLES."""
@ -1953,25 +2013,33 @@ async def test_pending_queue_cleanup_on_dispatch(tmp_path):
@pytest.mark.asyncio
async def test_followup_routed_to_pending_queue(tmp_path):
"""When a session has an active dispatch, follow-up messages go to pending queue."""
"""Unified-session follow-ups should route into the active pending queue."""
from nanobot.agent.loop import UNIFIED_SESSION_KEY
from nanobot.bus.events import InboundMessage
loop = _make_loop(tmp_path)
loop._unified_session = True
loop._dispatch = AsyncMock() # type: ignore[method-assign]
# Simulate an active dispatch by manually adding a pending queue
pending = asyncio.Queue(maxsize=20)
loop._pending_queues["cli:c"] = pending
loop._pending_queues[UNIFIED_SESSION_KEY] = pending
msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up")
run_task = asyncio.create_task(loop.run())
msg = InboundMessage(channel="discord", sender_id="u", chat_id="c", content="follow-up")
await loop.bus.publish_inbound(msg)
# Directly test the routing logic from run() — if session_key is in
# _pending_queues, the message should be put into the queue.
assert msg.session_key in loop._pending_queues
loop._pending_queues[msg.session_key].put_nowait(msg)
deadline = time.time() + 2
while pending.empty() and time.time() < deadline:
await asyncio.sleep(0.01)
loop.stop()
await asyncio.wait_for(run_task, timeout=2)
assert loop._dispatch.await_count == 0
assert not pending.empty()
queued_msg = pending.get_nowait()
assert queued_msg.content == "follow-up"
assert queued_msg.session_key == UNIFIED_SESSION_KEY
@pytest.mark.asyncio