diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 96b5b30c6..0031c90c5 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -129,6 +129,7 @@ class AgentLoop: """ _RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint" + _PENDING_USER_TURN_KEY = "pending_user_turn" def __init__( self, @@ -618,6 +619,8 @@ class AgentLoop: session = self.sessions.get_or_create(key) if self._restore_runtime_checkpoint(session): self.sessions.save(session) + if self._restore_pending_user_turn(session): + self.sessions.save(session) session, pending = self.auto_compact.prepare_session(session, key) @@ -653,6 +656,8 @@ class AgentLoop: session = self.sessions.get_or_create(key) if self._restore_runtime_checkpoint(session): self.sessions.save(session) + if self._restore_pending_user_turn(session): + self.sessions.save(session) session, pending = self.auto_compact.prepare_session(session, key) @@ -702,6 +707,7 @@ class AgentLoop: user_persisted_early = False if isinstance(msg.content, str) and msg.content.strip(): session.add_message("user", msg.content) + self._mark_pending_user_turn(session) self.sessions.save(session) user_persisted_early = True @@ -723,6 +729,7 @@ class AgentLoop: # Skip the already-persisted user message when saving the turn save_skip = 1 + len(history) + (1 if user_persisted_early else 0) self._save_turn(session, all_msgs, save_skip) + self._clear_pending_user_turn(session) self._clear_runtime_checkpoint(session) self.sessions.save(session) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) @@ -840,6 +847,12 @@ class AgentLoop: session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload self.sessions.save(session) + def _mark_pending_user_turn(self, session: Session) -> None: + session.metadata[self._PENDING_USER_TURN_KEY] = True + + def _clear_pending_user_turn(self, session: Session) -> None: + session.metadata.pop(self._PENDING_USER_TURN_KEY, None) + def _clear_runtime_checkpoint(self, session: Session) -> None: if self._RUNTIME_CHECKPOINT_KEY in session.metadata: session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None) @@ -906,9 +919,30 @@ class AgentLoop: break session.messages.extend(restored_messages[overlap:]) + self._clear_pending_user_turn(session) self._clear_runtime_checkpoint(session) return True + def _restore_pending_user_turn(self, session: Session) -> bool: + """Close a turn that only persisted the user message before crashing.""" + from datetime import datetime + + if not session.metadata.get(self._PENDING_USER_TURN_KEY): + return False + + if session.messages and session.messages[-1].get("role") == "user": + session.messages.append( + { + "role": "assistant", + "content": "Error: Task interrupted before a response was generated.", + "timestamp": datetime.now().isoformat(), + } + ) + session.updated_at = datetime.now() + + self._clear_pending_user_turn(session) + return True + async def process_direct( self, content: str, diff --git a/tests/agent/test_loop_save_turn.py b/tests/agent/test_loop_save_turn.py index c499282ab..c965ccd8c 100644 --- a/tests/agent/test_loop_save_turn.py +++ b/tests/agent/test_loop_save_turn.py @@ -229,6 +229,7 @@ async def test_process_message_persists_user_message_before_turn_completes(tmp_p persisted = loop.sessions.get_or_create("feishu:c1") assert [m["role"] for m in persisted.messages] == ["user"] assert persisted.messages[0]["content"] == "persist me" + assert persisted.metadata.get(AgentLoop._PENDING_USER_TURN_KEY) is True assert persisted.updated_at >= persisted.created_at @@ -262,3 +263,48 @@ async def test_process_message_does_not_duplicate_early_persisted_user_message(t {"role": "user", "content": "hello"}, {"role": "assistant", "content": "done"}, ] + assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata + + +@pytest.mark.asyncio +async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(tmp_path: Path) -> None: + loop = _make_full_loop(tmp_path) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + loop.provider.chat_with_retry = AsyncMock(return_value=MagicMock()) # unused because _run_agent_loop is stubbed + + session = loop.sessions.get_or_create("feishu:c3") + session.add_message("user", "old question") + session.metadata[AgentLoop._PENDING_USER_TURN_KEY] = True + loop.sessions.save(session) + + loop._run_agent_loop = AsyncMock(return_value=( + "new answer", + None, + [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old question"}, + {"role": "assistant", "content": "Error: Task interrupted before a response was generated."}, + {"role": "user", "content": "new question"}, + {"role": "assistant", "content": "new answer"}, + ], + "stop", + False, + )) # type: ignore[method-assign] + + result = await loop._process_message( + InboundMessage(channel="feishu", sender_id="u1", chat_id="c3", content="new question") + ) + + assert result is not None + assert result.content == "new answer" + session = loop.sessions.get_or_create("feishu:c3") + assert [ + {k: v for k, v in m.items() if k in {"role", "content"}} + for m in session.messages + ] == [ + {"role": "user", "content": "old question"}, + {"role": "assistant", "content": "Error: Task interrupted before a response was generated."}, + {"role": "user", "content": "new question"}, + {"role": "assistant", "content": "new answer"}, + ] + assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata