mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-27 05:15:51 +00:00
feat(agent): enhance session key handling for follow-up messages
This commit is contained in:
parent
36d2a11e73
commit
f6c39ec946
@ -324,6 +324,12 @@ class AgentLoop:
|
|||||||
|
|
||||||
return format_tool_hints(tool_calls)
|
return format_tool_hints(tool_calls)
|
||||||
|
|
||||||
|
def _effective_session_key(self, msg: InboundMessage) -> str:
|
||||||
|
"""Return the session key used for task routing and mid-turn injections."""
|
||||||
|
if self._unified_session and not msg.session_key_override:
|
||||||
|
return UNIFIED_SESSION_KEY
|
||||||
|
return msg.session_key
|
||||||
|
|
||||||
async def _run_agent_loop(
|
async def _run_agent_loop(
|
||||||
self,
|
self,
|
||||||
initial_messages: list[dict],
|
initial_messages: list[dict],
|
||||||
@ -430,30 +436,32 @@ class AgentLoop:
|
|||||||
if result:
|
if result:
|
||||||
await self.bus.publish_outbound(result)
|
await self.bus.publish_outbound(result)
|
||||||
continue
|
continue
|
||||||
|
effective_key = self._effective_session_key(msg)
|
||||||
# If this session already has an active pending queue (i.e. a task
|
# If this session already has an active pending queue (i.e. a task
|
||||||
# is processing this session), route the message there for mid-turn
|
# is processing this session), route the message there for mid-turn
|
||||||
# injection instead of creating a competing task.
|
# injection instead of creating a competing task.
|
||||||
if msg.session_key in self._pending_queues:
|
if effective_key in self._pending_queues:
|
||||||
|
pending_msg = msg
|
||||||
|
if effective_key != msg.session_key:
|
||||||
|
pending_msg = dataclasses.replace(
|
||||||
|
msg,
|
||||||
|
session_key_override=effective_key,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
self._pending_queues[msg.session_key].put_nowait(msg)
|
self._pending_queues[effective_key].put_nowait(pending_msg)
|
||||||
except asyncio.QueueFull:
|
except asyncio.QueueFull:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Pending queue full for session {}, dropping follow-up",
|
"Pending queue full for session {}, dropping follow-up",
|
||||||
msg.session_key,
|
effective_key,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Routed follow-up message to pending queue for session {}",
|
"Routed follow-up message to pending queue for session {}",
|
||||||
msg.session_key,
|
effective_key,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
# Compute the effective session key before dispatching
|
# Compute the effective session key before dispatching
|
||||||
# This ensures /stop command can find tasks correctly when unified session is enabled
|
# This ensures /stop command can find tasks correctly when unified session is enabled
|
||||||
effective_key = (
|
|
||||||
UNIFIED_SESSION_KEY
|
|
||||||
if self._unified_session and not msg.session_key_override
|
|
||||||
else msg.session_key
|
|
||||||
)
|
|
||||||
task = asyncio.create_task(self._dispatch(msg))
|
task = asyncio.create_task(self._dispatch(msg))
|
||||||
self._active_tasks.setdefault(effective_key, []).append(task)
|
self._active_tasks.setdefault(effective_key, []).append(task)
|
||||||
task.add_done_callback(
|
task.add_done_callback(
|
||||||
@ -465,9 +473,9 @@ class AgentLoop:
|
|||||||
|
|
||||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||||
"""Process a message: per-session serial, cross-session concurrent."""
|
"""Process a message: per-session serial, cross-session concurrent."""
|
||||||
if self._unified_session and not msg.session_key_override:
|
session_key = self._effective_session_key(msg)
|
||||||
msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY)
|
if session_key != msg.session_key:
|
||||||
session_key = msg.session_key
|
msg = dataclasses.replace(msg, session_key_override=session_key)
|
||||||
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
|
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
|
||||||
gate = self._concurrency_gate or nullcontext()
|
gate = self._concurrency_gate or nullcontext()
|
||||||
|
|
||||||
|
|||||||
@ -309,6 +309,14 @@ class AgentRunner:
|
|||||||
await hook.after_iteration(context)
|
await hook.after_iteration(context)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
assistant_message: dict[str, Any] | None = None
|
||||||
|
if response.finish_reason != "error" and not is_blank_text(clean):
|
||||||
|
assistant_message = build_assistant_message(
|
||||||
|
clean,
|
||||||
|
reasoning_content=response.reasoning_content,
|
||||||
|
thinking_blocks=response.thinking_blocks,
|
||||||
|
)
|
||||||
|
|
||||||
# Check for mid-turn injections BEFORE signaling stream end.
|
# Check for mid-turn injections BEFORE signaling stream end.
|
||||||
# If injections are found we keep the stream alive (resuming=True)
|
# If injections are found we keep the stream alive (resuming=True)
|
||||||
# so streaming channels don't prematurely finalize the card.
|
# so streaming channels don't prematurely finalize the card.
|
||||||
@ -319,6 +327,19 @@ class AgentRunner:
|
|||||||
had_injections = True
|
had_injections = True
|
||||||
injection_cycles += 1
|
injection_cycles += 1
|
||||||
_injected_after_final = True
|
_injected_after_final = True
|
||||||
|
if assistant_message is not None:
|
||||||
|
messages.append(assistant_message)
|
||||||
|
await self._emit_checkpoint(
|
||||||
|
spec,
|
||||||
|
{
|
||||||
|
"phase": "final_response",
|
||||||
|
"iteration": iteration,
|
||||||
|
"model": spec.model,
|
||||||
|
"assistant_message": assistant_message,
|
||||||
|
"completed_tool_results": [],
|
||||||
|
"pending_tool_calls": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
for text in injections:
|
for text in injections:
|
||||||
messages.append({"role": "user", "content": text})
|
messages.append({"role": "user", "content": text})
|
||||||
logger.info(
|
logger.info(
|
||||||
@ -354,7 +375,7 @@ class AgentRunner:
|
|||||||
await hook.after_iteration(context)
|
await hook.after_iteration(context)
|
||||||
break
|
break
|
||||||
|
|
||||||
messages.append(build_assistant_message(
|
messages.append(assistant_message or build_assistant_message(
|
||||||
clean,
|
clean,
|
||||||
reasoning_content=response.reasoning_content,
|
reasoning_content=response.reasoning_content,
|
||||||
thinking_blocks=response.thinking_blocks,
|
thinking_blocks=response.thinking_blocks,
|
||||||
|
|||||||
@ -1862,6 +1862,66 @@ async def test_checkpoint2_injects_after_final_response_with_resuming_stream():
|
|||||||
assert stream_end_calls[-1] is False
|
assert stream_end_calls[-1] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkpoint2_preserves_final_response_in_history_before_followup():
|
||||||
|
"""A follow-up injected after a final answer must still see that answer in history."""
|
||||||
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
provider = MagicMock()
|
||||||
|
call_count = {"n": 0}
|
||||||
|
captured_messages = []
|
||||||
|
|
||||||
|
async def chat_with_retry(*, messages, **kwargs):
|
||||||
|
call_count["n"] += 1
|
||||||
|
captured_messages.append([dict(message) for message in messages])
|
||||||
|
if call_count["n"] == 1:
|
||||||
|
return LLMResponse(content="first answer", tool_calls=[], usage={})
|
||||||
|
return LLMResponse(content="second answer", tool_calls=[], usage={})
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
|
||||||
|
injection_queue = asyncio.Queue()
|
||||||
|
|
||||||
|
async def inject_cb():
|
||||||
|
items = []
|
||||||
|
while not injection_queue.empty():
|
||||||
|
items.append(await injection_queue.get())
|
||||||
|
return items
|
||||||
|
|
||||||
|
await injection_queue.put(
|
||||||
|
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question")
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
result = await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[{"role": "user", "content": "hello"}],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=5,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
injection_callback=inject_cb,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert result.final_content == "second answer"
|
||||||
|
assert call_count["n"] == 2
|
||||||
|
assert captured_messages[-1] == [
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
{"role": "assistant", "content": "first answer"},
|
||||||
|
{"role": "user", "content": "follow-up question"},
|
||||||
|
]
|
||||||
|
assert [
|
||||||
|
{"role": message["role"], "content": message["content"]}
|
||||||
|
for message in result.messages
|
||||||
|
if message.get("role") == "assistant"
|
||||||
|
] == [
|
||||||
|
{"role": "assistant", "content": "first answer"},
|
||||||
|
{"role": "assistant", "content": "second answer"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_injection_cycles_capped_at_max():
|
async def test_injection_cycles_capped_at_max():
|
||||||
"""Injection cycles should be capped at _MAX_INJECTION_CYCLES."""
|
"""Injection cycles should be capped at _MAX_INJECTION_CYCLES."""
|
||||||
@ -1953,25 +2013,33 @@ async def test_pending_queue_cleanup_on_dispatch(tmp_path):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_followup_routed_to_pending_queue(tmp_path):
|
async def test_followup_routed_to_pending_queue(tmp_path):
|
||||||
"""When a session has an active dispatch, follow-up messages go to pending queue."""
|
"""Unified-session follow-ups should route into the active pending queue."""
|
||||||
|
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
loop = _make_loop(tmp_path)
|
loop = _make_loop(tmp_path)
|
||||||
|
loop._unified_session = True
|
||||||
|
loop._dispatch = AsyncMock() # type: ignore[method-assign]
|
||||||
|
|
||||||
# Simulate an active dispatch by manually adding a pending queue
|
|
||||||
pending = asyncio.Queue(maxsize=20)
|
pending = asyncio.Queue(maxsize=20)
|
||||||
loop._pending_queues["cli:c"] = pending
|
loop._pending_queues[UNIFIED_SESSION_KEY] = pending
|
||||||
|
|
||||||
msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up")
|
run_task = asyncio.create_task(loop.run())
|
||||||
|
msg = InboundMessage(channel="discord", sender_id="u", chat_id="c", content="follow-up")
|
||||||
|
await loop.bus.publish_inbound(msg)
|
||||||
|
|
||||||
# Directly test the routing logic from run() — if session_key is in
|
deadline = time.time() + 2
|
||||||
# _pending_queues, the message should be put into the queue.
|
while pending.empty() and time.time() < deadline:
|
||||||
assert msg.session_key in loop._pending_queues
|
await asyncio.sleep(0.01)
|
||||||
loop._pending_queues[msg.session_key].put_nowait(msg)
|
|
||||||
|
|
||||||
|
loop.stop()
|
||||||
|
await asyncio.wait_for(run_task, timeout=2)
|
||||||
|
|
||||||
|
assert loop._dispatch.await_count == 0
|
||||||
assert not pending.empty()
|
assert not pending.empty()
|
||||||
queued_msg = pending.get_nowait()
|
queued_msg = pending.get_nowait()
|
||||||
assert queued_msg.content == "follow-up"
|
assert queued_msg.content == "follow-up"
|
||||||
|
assert queued_msg.session_key == UNIFIED_SESSION_KEY
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user