mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
fix(agent): address session and streaming concurrency bugs
This commit is contained in:
parent
1a4ae8994d
commit
0df60416ba
@ -812,16 +812,16 @@ class AgentLoop:
|
|||||||
logger.warning("Error consuming inbound message: {}, continuing...", e)
|
logger.warning("Error consuming inbound message: {}, continuing...", e)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
raw = msg.content.strip()
|
||||||
|
effective_key = self._effective_session_key(msg)
|
||||||
if await agent_context.handle_runtime_control(self, msg, self.tools):
|
if await agent_context.handle_runtime_control(self, msg, self.tools):
|
||||||
continue
|
continue
|
||||||
raw = msg.content.strip()
|
|
||||||
if self.commands.is_priority(raw):
|
if self.commands.is_priority(raw):
|
||||||
await self._dispatch_command_inline(
|
await self._dispatch_command_inline(
|
||||||
msg, msg.session_key, raw,
|
msg, effective_key, raw,
|
||||||
self.commands.dispatch_priority,
|
self.commands.dispatch_priority,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
effective_key = self._effective_session_key(msg)
|
|
||||||
# If this session already has an active pending queue (i.e. a task
|
# If this session already has an active pending queue (i.e. a task
|
||||||
# is processing this session), route the message there for mid-turn
|
# is processing this session), route the message there for mid-turn
|
||||||
# injection instead of creating a competing task.
|
# injection instead of creating a competing task.
|
||||||
@ -872,13 +872,13 @@ class AgentLoop:
|
|||||||
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
|
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
|
||||||
gate = self._concurrency_gate or nullcontext()
|
gate = self._concurrency_gate or nullcontext()
|
||||||
|
|
||||||
# Register a pending queue so follow-up messages for this session are
|
pending: asyncio.Queue | None = None
|
||||||
# routed here (mid-turn injection) instead of spawning a new task.
|
|
||||||
pending = asyncio.Queue(maxsize=20)
|
|
||||||
self._pending_queues[session_key] = pending
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with lock, gate:
|
async with lock, gate:
|
||||||
|
# Only the task that owns the session lock may publish the
|
||||||
|
# active mid-turn injection queue for this session.
|
||||||
|
pending = asyncio.Queue(maxsize=20)
|
||||||
|
self._pending_queues[session_key] = pending
|
||||||
try:
|
try:
|
||||||
on_stream = on_stream_end = None
|
on_stream = on_stream_end = None
|
||||||
if msg.metadata.get("_wants_stream"):
|
if msg.metadata.get("_wants_stream"):
|
||||||
@ -962,28 +962,39 @@ class AgentLoop:
|
|||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content="Sorry, I encountered an error.",
|
content="Sorry, I encountered an error.",
|
||||||
))
|
))
|
||||||
|
finally:
|
||||||
|
# Drain any messages still in the pending queue and re-publish
|
||||||
|
# them to the bus so they are processed as fresh inbound messages
|
||||||
|
# rather than silently lost. Only remove our own queue; a
|
||||||
|
# later task waiting on the lock must not be able to steal
|
||||||
|
# cleanup ownership.
|
||||||
|
queue = None
|
||||||
|
if self._pending_queues.get(session_key) is pending:
|
||||||
|
queue = self._pending_queues.pop(session_key, None)
|
||||||
|
else:
|
||||||
|
queue = pending
|
||||||
|
if queue is not None:
|
||||||
|
leftover = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
item = queue.get_nowait()
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
break
|
||||||
|
await self.bus.publish_inbound(item)
|
||||||
|
leftover += 1
|
||||||
|
if leftover:
|
||||||
|
logger.info(
|
||||||
|
"Re-published {} leftover message(s) to bus for session {}",
|
||||||
|
leftover, session_key,
|
||||||
|
)
|
||||||
|
await self._webui_turns.publish_run_status(msg, "idle")
|
||||||
|
self._pending_turn_latency_ms.pop(session_key, None)
|
||||||
|
self._webui_turns.discard(session_key)
|
||||||
finally:
|
finally:
|
||||||
# Drain any messages still in the pending queue and re-publish
|
if pending is None:
|
||||||
# them to the bus so they are processed as fresh inbound messages
|
await self._webui_turns.publish_run_status(msg, "idle")
|
||||||
# rather than silently lost.
|
self._pending_turn_latency_ms.pop(session_key, None)
|
||||||
queue = self._pending_queues.pop(session_key, None)
|
self._webui_turns.discard(session_key)
|
||||||
if queue is not None:
|
|
||||||
leftover = 0
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
item = queue.get_nowait()
|
|
||||||
except asyncio.QueueEmpty:
|
|
||||||
break
|
|
||||||
await self.bus.publish_inbound(item)
|
|
||||||
leftover += 1
|
|
||||||
if leftover:
|
|
||||||
logger.info(
|
|
||||||
"Re-published {} leftover message(s) to bus for session {}",
|
|
||||||
leftover, session_key,
|
|
||||||
)
|
|
||||||
await self._webui_turns.publish_run_status(msg, "idle")
|
|
||||||
self._pending_turn_latency_ms.pop(session_key, None)
|
|
||||||
self._webui_turns.discard(session_key)
|
|
||||||
|
|
||||||
async def close_mcp(self) -> None:
|
async def close_mcp(self) -> None:
|
||||||
"""Drain pending background archives, then close MCP connections."""
|
"""Drain pending background archives, then close MCP connections."""
|
||||||
|
|||||||
@ -1273,7 +1273,13 @@ class AgentRunner:
|
|||||||
return messages
|
return messages
|
||||||
|
|
||||||
system_tokens = sum(estimate_message_tokens(msg) for msg in system_messages)
|
system_tokens = sum(estimate_message_tokens(msg) for msg in system_messages)
|
||||||
remaining_budget = max(128, budget - system_tokens)
|
fixed_tokens, _ = estimate_prompt_tokens_chain(
|
||||||
|
self.provider,
|
||||||
|
spec.model,
|
||||||
|
system_messages,
|
||||||
|
spec.tools.get_definitions(),
|
||||||
|
)
|
||||||
|
remaining_budget = max(0, budget - max(system_tokens, fixed_tokens))
|
||||||
kept: list[dict[str, Any]] = []
|
kept: list[dict[str, Any]] = []
|
||||||
kept_tokens = 0
|
kept_tokens = 0
|
||||||
for message in reversed(non_system):
|
for message in reversed(non_system):
|
||||||
|
|||||||
@ -16,6 +16,7 @@ There is **no** sub-agent orchestrator and **no** special WebSocket ``agent_ui``
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextvars import ContextVar
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
@ -45,15 +46,19 @@ class _GoalToolsMixin(ContextAware):
|
|||||||
def __init__(self, sessions: SessionManager, bus: Any | None = None) -> None:
|
def __init__(self, sessions: SessionManager, bus: Any | None = None) -> None:
|
||||||
self._sessions = sessions
|
self._sessions = sessions
|
||||||
self._bus = bus
|
self._bus = bus
|
||||||
self._request_ctx: RequestContext | None = None
|
self._request_ctx: ContextVar[RequestContext | None] = ContextVar(
|
||||||
|
f"{self.__class__.__name__}_request_ctx",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
def set_context(self, ctx: RequestContext) -> None:
|
def set_context(self, ctx: RequestContext) -> None:
|
||||||
self._request_ctx = ctx
|
self._request_ctx.set(ctx)
|
||||||
|
|
||||||
def _session(self):
|
def _session(self):
|
||||||
if self._request_ctx is None:
|
request_ctx = self._request_ctx.get()
|
||||||
|
if request_ctx is None:
|
||||||
return None
|
return None
|
||||||
key = self._request_ctx.session_key
|
key = request_ctx.session_key
|
||||||
if not key:
|
if not key:
|
||||||
return None
|
return None
|
||||||
return self._sessions.get_or_create(key)
|
return self._sessions.get_or_create(key)
|
||||||
@ -61,7 +66,7 @@ class _GoalToolsMixin(ContextAware):
|
|||||||
async def _publish_goal_state_ws(self, metadata: dict[str, Any]) -> None:
|
async def _publish_goal_state_ws(self, metadata: dict[str, Any]) -> None:
|
||||||
"""Fan-out authoritative goal snapshot for this WebSocket chat only."""
|
"""Fan-out authoritative goal snapshot for this WebSocket chat only."""
|
||||||
bus = self._bus
|
bus = self._bus
|
||||||
rc = self._request_ctx
|
rc = self._request_ctx.get()
|
||||||
if bus is None or rc is None or rc.channel != "websocket":
|
if bus is None or rc is None or rc.channel != "websocket":
|
||||||
return
|
return
|
||||||
cid = (rc.chat_id or "").strip()
|
cid = (rc.chat_id or "").strip()
|
||||||
@ -224,4 +229,3 @@ class CompleteGoalTool(Tool, _GoalToolsMixin):
|
|||||||
if tail:
|
if tail:
|
||||||
return f"Goal marked complete ({ended}). Recap:\n{tail}"
|
return f"Goal marked complete ({ended}). Recap:\n{tail}"
|
||||||
return f"Goal marked complete ({ended})."
|
return f"Goal marked complete ({ended})."
|
||||||
|
|
||||||
|
|||||||
@ -123,7 +123,7 @@ async def cmd_stop(ctx: CommandContext) -> OutboundMessage:
|
|||||||
"""Cancel all active tasks and subagents for the session."""
|
"""Cancel all active tasks and subagents for the session."""
|
||||||
loop = ctx.loop
|
loop = ctx.loop
|
||||||
msg = ctx.msg
|
msg = ctx.msg
|
||||||
total = await loop._cancel_active_tasks(msg.session_key)
|
total = await loop._cancel_active_tasks(ctx.key)
|
||||||
content = f"Stopped {total} task(s)." if total else "No active task to stop."
|
content = f"Stopped {total} task(s)." if total else "No active task to stop."
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
||||||
|
|||||||
@ -557,11 +557,20 @@ class LLMProvider(ABC):
|
|||||||
if reasoning_effort is self._SENTINEL:
|
if reasoning_effort is self._SENTINEL:
|
||||||
reasoning_effort = self.generation.reasoning_effort
|
reasoning_effort = self.generation.reasoning_effort
|
||||||
|
|
||||||
|
has_streamed_content = False
|
||||||
|
|
||||||
|
async def _tracking_delta(text: str) -> None:
|
||||||
|
nonlocal has_streamed_content
|
||||||
|
if text:
|
||||||
|
has_streamed_content = True
|
||||||
|
if on_content_delta:
|
||||||
|
await on_content_delta(text)
|
||||||
|
|
||||||
kw: dict[str, Any] = dict(
|
kw: dict[str, Any] = dict(
|
||||||
messages=messages, tools=tools, model=model,
|
messages=messages, tools=tools, model=model,
|
||||||
max_tokens=max_tokens, temperature=temperature,
|
max_tokens=max_tokens, temperature=temperature,
|
||||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||||
on_content_delta=on_content_delta,
|
on_content_delta=_tracking_delta if on_content_delta is not None else None,
|
||||||
on_thinking_delta=on_thinking_delta,
|
on_thinking_delta=on_thinking_delta,
|
||||||
on_tool_call_delta=on_tool_call_delta,
|
on_tool_call_delta=on_tool_call_delta,
|
||||||
)
|
)
|
||||||
@ -571,6 +580,7 @@ class LLMProvider(ABC):
|
|||||||
messages,
|
messages,
|
||||||
retry_mode=retry_mode,
|
retry_mode=retry_mode,
|
||||||
on_retry_wait=on_retry_wait,
|
on_retry_wait=on_retry_wait,
|
||||||
|
should_retry_guard=lambda: not has_streamed_content,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def chat_with_retry(
|
async def chat_with_retry(
|
||||||
@ -717,6 +727,7 @@ class LLMProvider(ABC):
|
|||||||
*,
|
*,
|
||||||
retry_mode: str,
|
retry_mode: str,
|
||||||
on_retry_wait: Callable[[str], Awaitable[None]] | None,
|
on_retry_wait: Callable[[str], Awaitable[None]] | None,
|
||||||
|
should_retry_guard: Callable[[], bool] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
attempt = 0
|
attempt = 0
|
||||||
delays = list(self._CHAT_RETRY_DELAYS)
|
delays = list(self._CHAT_RETRY_DELAYS)
|
||||||
@ -730,6 +741,11 @@ class LLMProvider(ABC):
|
|||||||
if response.finish_reason != "error":
|
if response.finish_reason != "error":
|
||||||
return response
|
return response
|
||||||
last_response = response
|
last_response = response
|
||||||
|
if should_retry_guard is not None and not should_retry_guard():
|
||||||
|
logger.warning(
|
||||||
|
"LLM stream failed after content was emitted; skipping retry"
|
||||||
|
)
|
||||||
|
return response
|
||||||
error_key = ((response.content or "").strip().lower() or None)
|
error_key = ((response.content or "").strip().lower() or None)
|
||||||
if error_key and error_key == last_error_key:
|
if error_key and error_key == last_error_key:
|
||||||
identical_error_count += 1
|
identical_error_count += 1
|
||||||
|
|||||||
@ -105,6 +105,60 @@ def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch
|
|||||||
assert trimmed[0]["role"] == "system"
|
assert trimmed[0]["role"] == "system"
|
||||||
non_system = [m for m in trimmed if m["role"] != "system"]
|
non_system = [m for m in trimmed if m["role"] != "system"]
|
||||||
assert non_system[0]["role"] == "user", f"Expected user after system, got {non_system[0]['role']}"
|
assert non_system[0]["role"] == "user", f"Expected user after system, got {non_system[0]['role']}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_snip_history_reserves_budget_for_tool_definitions(monkeypatch):
|
||||||
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||||
|
|
||||||
|
provider = MagicMock()
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = [{"type": "function", "function": {"name": "large_tool"}}]
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "system"},
|
||||||
|
{"role": "user", "content": "old user"},
|
||||||
|
{"role": "assistant", "content": "old assistant"},
|
||||||
|
{"role": "user", "content": "recent one"},
|
||||||
|
{"role": "assistant", "content": "recent answer"},
|
||||||
|
{"role": "user", "content": "recent two"},
|
||||||
|
]
|
||||||
|
spec = AgentRunSpec(
|
||||||
|
initial_messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=1,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
context_window_tokens=2000,
|
||||||
|
context_block_limit=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _estimate(_provider, _model, estimate_messages, estimate_tools):
|
||||||
|
if estimate_messages == messages:
|
||||||
|
return 1000, None
|
||||||
|
assert estimate_messages == [{"role": "system", "content": "system"}]
|
||||||
|
assert estimate_tools == tools.get_definitions.return_value
|
||||||
|
return 350, None
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", _estimate)
|
||||||
|
token_sizes = {
|
||||||
|
"system": 50,
|
||||||
|
"old user": 200,
|
||||||
|
"old assistant": 200,
|
||||||
|
"recent one": 200,
|
||||||
|
"recent answer": 200,
|
||||||
|
"recent two": 200,
|
||||||
|
}
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.agent.runner.estimate_message_tokens",
|
||||||
|
lambda msg: token_sizes.get(str(msg.get("content")), 40),
|
||||||
|
)
|
||||||
|
|
||||||
|
trimmed = runner._snip_history(spec, messages)
|
||||||
|
|
||||||
|
contents = [message.get("content") for message in trimmed]
|
||||||
|
assert contents == ["system", "recent two"]
|
||||||
|
|
||||||
|
|
||||||
async def test_backfill_missing_tool_results_inserts_error():
|
async def test_backfill_missing_tool_results_inserts_error():
|
||||||
"""Orphaned tool_use (no matching tool_result) should get a synthetic error."""
|
"""Orphaned tool_use (no matching tool_result) should get a synthetic error."""
|
||||||
from nanobot.agent.runner import AgentRunner, _BACKFILL_CONTENT
|
from nanobot.agent.runner import AgentRunner, _BACKFILL_CONTENT
|
||||||
|
|||||||
@ -554,6 +554,31 @@ async def test_pending_queue_cleanup_on_dispatch(tmp_path):
|
|||||||
assert msg.session_key not in loop._pending_queues
|
assert msg.session_key not in loop._pending_queues
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_waiting_dispatch_does_not_replace_active_pending_queue(tmp_path):
|
||||||
|
"""A queued dispatch must not steal the active task's injection queue."""
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
loop = _make_loop(tmp_path)
|
||||||
|
session_key = "cli:c"
|
||||||
|
lock = loop._session_locks.setdefault(session_key, asyncio.Lock())
|
||||||
|
await lock.acquire()
|
||||||
|
active_pending = asyncio.Queue(maxsize=1)
|
||||||
|
loop._pending_queues[session_key] = active_pending
|
||||||
|
|
||||||
|
waiting = asyncio.create_task(
|
||||||
|
loop._dispatch(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="queued"))
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
assert loop._pending_queues[session_key] is active_pending
|
||||||
|
|
||||||
|
waiting.cancel()
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await waiting
|
||||||
|
lock.release()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_followup_routed_to_pending_queue(tmp_path):
|
async def test_followup_routed_to_pending_queue(tmp_path):
|
||||||
"""Unified-session follow-ups should route into the active pending queue."""
|
"""Unified-session follow-ups should route into the active pending queue."""
|
||||||
|
|||||||
@ -474,6 +474,32 @@ class TestStopCommandWithUnifiedSession:
|
|||||||
assert task.cancelled() or task.done()
|
assert task.cancelled() or task.done()
|
||||||
assert "Stopped 1 task" in result.content
|
assert "Stopped 1 task" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_command_uses_effective_key_without_session_override(self, tmp_path: Path):
|
||||||
|
"""Priority /stop must cancel the unified session even before dispatch rewrites the message."""
|
||||||
|
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||||
|
from nanobot.command.builtin import cmd_stop
|
||||||
|
|
||||||
|
loop = _make_loop(tmp_path, unified_session=True)
|
||||||
|
|
||||||
|
async def long_running():
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
task = asyncio.create_task(long_running())
|
||||||
|
loop._active_tasks[UNIFIED_SESSION_KEY] = [task]
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="telegram",
|
||||||
|
chat_id="123456",
|
||||||
|
sender_id="user1",
|
||||||
|
content="/stop",
|
||||||
|
)
|
||||||
|
ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)
|
||||||
|
|
||||||
|
result = await cmd_stop(ctx)
|
||||||
|
|
||||||
|
assert task.cancelled() or task.done()
|
||||||
|
assert "Stopped 1 task" in result.content
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stop_command_cross_channel_in_unified_mode(self, tmp_path: Path):
|
async def test_stop_command_cross_channel_in_unified_mode(self, tmp_path: Path):
|
||||||
"""In unified mode, /stop from one channel cancels tasks from another channel."""
|
"""In unified mode, /stop from one channel cancels tasks from another channel."""
|
||||||
@ -504,4 +530,4 @@ class TestStopCommandWithUnifiedSession:
|
|||||||
result = await cmd_stop(ctx)
|
result = await cmd_stop(ctx)
|
||||||
|
|
||||||
# Both tasks should be cancelled
|
# Both tasks should be cancelled
|
||||||
assert "Stopped 2 task" in result.content
|
assert "Stopped 2 task" in result.content
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -72,6 +73,33 @@ async def test_complete_goal_closes_active_goal(tmp_path):
|
|||||||
assert blob["recap"] == "Done."
|
assert blob["recap"] == "Done."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_goal_tools_keep_request_context_per_task(tmp_path):
|
||||||
|
sm = SessionManager(tmp_path)
|
||||||
|
lt = LongTaskTool(sessions=sm)
|
||||||
|
cg = CompleteGoalTool(sessions=sm)
|
||||||
|
ctx_a = RequestContext(channel="websocket", chat_id="a", session_key="websocket:a")
|
||||||
|
ctx_b = RequestContext(channel="websocket", chat_id="b", session_key="websocket:b")
|
||||||
|
|
||||||
|
lt.set_context(ctx_a)
|
||||||
|
task_a = asyncio.create_task(lt.execute(goal="Goal A"))
|
||||||
|
lt.set_context(ctx_b)
|
||||||
|
task_b = asyncio.create_task(lt.execute(goal="Goal B"))
|
||||||
|
await asyncio.gather(task_a, task_b)
|
||||||
|
|
||||||
|
assert sm.get_or_create("websocket:a").metadata[GOAL_STATE_KEY]["objective"] == "Goal A"
|
||||||
|
assert sm.get_or_create("websocket:b").metadata[GOAL_STATE_KEY]["objective"] == "Goal B"
|
||||||
|
|
||||||
|
cg.set_context(ctx_a)
|
||||||
|
done_a = asyncio.create_task(cg.execute(recap="Done A"))
|
||||||
|
cg.set_context(ctx_b)
|
||||||
|
done_b = asyncio.create_task(cg.execute(recap="Done B"))
|
||||||
|
await asyncio.gather(done_a, done_b)
|
||||||
|
|
||||||
|
assert sm.get_or_create("websocket:a").metadata[GOAL_STATE_KEY]["recap"] == "Done A"
|
||||||
|
assert sm.get_or_create("websocket:b").metadata[GOAL_STATE_KEY]["recap"] == "Done B"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_long_task_publishes_goal_state_ws_after_save(tmp_path):
|
async def test_long_task_publishes_goal_state_ws_after_save(tmp_path):
|
||||||
bus = MagicMock()
|
bus = MagicMock()
|
||||||
|
|||||||
@ -21,6 +21,17 @@ class ScriptedProvider(LLMProvider):
|
|||||||
raise response
|
raise response
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
async def chat_stream(self, *args, **kwargs) -> LLMResponse:
|
||||||
|
self.calls += 1
|
||||||
|
self.last_kwargs = kwargs
|
||||||
|
response = self._responses.pop(0)
|
||||||
|
if isinstance(response, BaseException):
|
||||||
|
raise response
|
||||||
|
delta = getattr(response, "_test_stream_delta", None)
|
||||||
|
if delta and kwargs.get("on_content_delta"):
|
||||||
|
await kwargs["on_content_delta"](delta)
|
||||||
|
return response
|
||||||
|
|
||||||
def get_default_model(self) -> str:
|
def get_default_model(self) -> str:
|
||||||
return "test-model"
|
return "test-model"
|
||||||
|
|
||||||
@ -122,6 +133,36 @@ async def test_chat_with_retry_preserves_cancelled_error() -> None:
|
|||||||
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_stream_with_retry_does_not_retry_after_emitting_content(monkeypatch) -> None:
|
||||||
|
first = LLMResponse(content="stream stalled", finish_reason="error")
|
||||||
|
first._test_stream_delta = "partial" # type: ignore[attr-defined]
|
||||||
|
provider = ScriptedProvider([
|
||||||
|
first,
|
||||||
|
LLMResponse(content="ok"),
|
||||||
|
])
|
||||||
|
deltas: list[str] = []
|
||||||
|
delays: list[int] = []
|
||||||
|
|
||||||
|
async def _fake_sleep(delay: int) -> None:
|
||||||
|
delays.append(delay)
|
||||||
|
|
||||||
|
async def _on_delta(delta: str) -> None:
|
||||||
|
deltas.append(delta)
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||||
|
|
||||||
|
response = await provider.chat_stream_with_retry(
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
on_content_delta=_on_delta,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.content == "stream stalled"
|
||||||
|
assert provider.calls == 1
|
||||||
|
assert deltas == ["partial"]
|
||||||
|
assert delays == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
|
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
|
||||||
"""When callers omit generation params, provider.generation defaults are used."""
|
"""When callers omit generation params, provider.generation defaults are used."""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user