Mirror of https://github.com/HKUDS/nanobot.git (synced 2026-05-20 00:22:31 +00:00)
* feat(long-task): add LongTaskTool for multi-step agent tasks
Implements a meta-ReAct loop where long-running tasks are broken into
sequential subagent steps, each starting fresh with the original goal
and progress from the previous step. This prevents context drift when
agents work on complex, multi-step tasks.
- Extract build_tool_registry() from SubagentManager for reuse
- Add run_step() for synchronous subagent execution (no bus announcement)
- Add HandoffTool and CompleteTool as signal mechanisms via shared dict
- Add LongTaskTool orchestrator with simplified prompt (8 iterations/step)
- Register LongTaskTool in main agent loop
- Add _extract_handoff_from_messages fallback for robustness
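The meta-ReAct loop described above can be sketched as follows. This is a minimal illustration under stated assumptions, not nanobot's implementation:

- `run_long_task` and the `run_step(prompt, signals)` callable are hypothetical stand-ins for LongTaskTool and `run_step()`.
- The `"handoff"`/`"complete"` keys model the shared dict that HandoffTool and CompleteTool write into.
- Where the real tool falls back to `_extract_handoff_from_messages`, this sketch simply keeps the previous progress.

```python
from typing import Callable

# Hypothetical signature: run_step(prompt, signals) runs one fresh subagent
# step and reports its outcome by writing into the shared `signals` dict.
StepRunner = Callable[[str, dict], None]


def run_long_task(goal: str, run_step: StepRunner, max_steps: int = 5) -> str:
    progress = ""  # progress handed off by the previous step
    for step in range(1, max_steps + 1):
        # Every step starts fresh: only the original goal and the last
        # handoff are carried over, which keeps context from drifting.
        prompt = f"Goal: {goal}\nStep {step}/{max_steps}\nPrevious progress: {progress or '(none)'}"
        signals: dict = {}
        run_step(prompt, signals)
        if "complete" in signals:
            return signals["complete"]
        # Simplification: without an explicit handoff, keep the old progress.
        progress = signals.get("handoff", progress)
    return f"Stopped after {max_steps} steps. Last progress: {progress}"


def fake_step(prompt: str, signals: dict) -> None:
    # Toy subagent: hands off on steps 1-2 and completes on step 3.
    if "Step 3/" in prompt:
        signals["complete"] = "done: report written"
    else:
        signals["handoff"] = f"finished part of: {prompt.splitlines()[1]}"


print(run_long_task("write a report", fake_step))
```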
* fix(long-task): add debug logging for step-level observability
* feat(long-task): major overhaul with structured handoffs, validation, and observability
- Structured HandoffState: HandoffTool now accepts files_created,
  files_modified, next_step_hint, and verification fields instead of
  a plain string. Progress is passed between steps as structured data.
- Completion validation round: After complete() is called, a dedicated
  validator step runs to verify the claim against the original goal.
  If validation fails, the task continues rather than returning
  a false completion.
- Dynamic prompt system: 3 Jinja2 templates (step_start, step_middle,
  step_final) selected based on step number. Final steps get a tighter
  budget and stronger "wrap up" guidance.
- Automatic file change tracking: Extracts write_file/edit_file events
  from tool_events and injects them into the next step's context if
  the subagent forgot to report them explicitly.
- Budget tracking & adaptive strategy: Cumulative token usage is tracked
  across steps. Per-step tool budget drops from 8 to 4 in the last
  two steps to force handoff/completion.
- Crash retry with graceful degradation: A step that crashes is retried
  once. Persistent crashes terminate the task and return partial progress.
- Full observability hooks for future WebUI integration:
  - set_hooks() with on_step_start, on_step_complete, on_handoff,
    on_validation_started, on_validation_passed, on_validation_failed,
    on_task_complete, on_task_error, and catch-all on_event.
  - Readable state properties: current_step, total_steps, status,
    last_handoff, cumulative_usage, goal.
  - inject_correction() allows external code to send user corrections
    that are injected into the next step's prompt.
- run_step() accepts optional max_iterations for dynamic budget control.
All 27 long-task tests and 11 subagent tests pass.
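The step-dependent template and budget choice can be sketched like this. The sketch makes three assumptions:

- Step numbers are 1-based.
- The "last two steps" are the steps where `step > max_steps - 2`, matching the `max_steps - 2` threshold named in the later review fixes.
- The final template takes precedence even for step 1 when `max_steps` is 1 or 2.

`plan_step` and `StepPlan` are hypothetical names.

```python
from typing import NamedTuple


class StepPlan(NamedTuple):
    template: str        # which prompt template to render
    max_iterations: int  # per-step tool budget passed to run_step()


NORMAL_BUDGET = 8
FINAL_BUDGET = 4


def plan_step(step: int, max_steps: int) -> StepPlan:
    """Pick the prompt template and tool budget for a 1-based step number."""
    # Final window = the last two steps; this also covers max_steps == 1.
    if step > max_steps - 2:
        return StepPlan("step_final", FINAL_BUDGET)
    if step == 1:
        return StepPlan("step_start", NORMAL_BUDGET)
    return StepPlan("step_middle", NORMAL_BUDGET)


for s in range(1, 6):
    print(s, plan_step(s, max_steps=5))
```

Sharing one threshold between template choice and budget keeps the "wrap up" wording and the reduced tool budget from drifting apart.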
* test(long-task): add boundary tests and fix race conditions
- Add 7 edge-case tests: validation crash resilience, hook exception safety, mid-run correction injection, FIFO correction ordering, explicit file changes overriding auto-detection, final budget for max_steps=1, and dynamic budget switching boundaries
- Fix assertion in test_long_task_completes_after_multiple_handoffs to match exact prompt format
- Remove asyncio timing hack from test_state_exposure
- Add asyncio.sleep(0) yield in test_inject_correction_during_execution to prevent race between signal injection and step continuation
- All 34 tests passing
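The correction behaviour these tests exercise (injection while a step is running, FIFO ordering) can be modelled with a queue that the orchestrator drains at each step boundary. `CorrectionInbox`, `drain_into_prompt`, and the "User corrections" prompt format are hypothetical. Only the name `inject_correction()` comes from the commit.

```python
import asyncio
from collections import deque


class CorrectionInbox:
    """Hypothetical holder for user corrections between long-task steps."""

    def __init__(self) -> None:
        self._pending: deque[str] = deque()

    def inject_correction(self, text: str) -> None:
        # Safe to call while a step is running: the text is only read at the
        # next step boundary.
        self._pending.append(text)

    def drain_into_prompt(self, prompt: str) -> str:
        """Append all pending corrections, oldest first, then clear them."""
        if not self._pending:
            return prompt
        lines = []
        while self._pending:
            lines.append(f"- {self._pending.popleft()}")
        return prompt + "\n\nUser corrections:\n" + "\n".join(lines)


async def demo() -> str:
    inbox = CorrectionInbox()

    async def running_step() -> None:
        await asyncio.sleep(0)  # stand-in for a step that is still running

    step = asyncio.create_task(running_step())
    inbox.inject_correction("use metric units")
    inbox.inject_correction("skip the appendix")
    await step
    return inbox.drain_into_prompt("Step 2 prompt")


print(asyncio.run(demo()))
```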
* fix(long-task): address code review findings
- Declare _scopes = {"core"} explicitly to prevent recursive nesting in subagent scope
- Document fragile coupling in _extract_file_changes: path extraction depends on
  write_file/edit_file detail format; add debug log for unexpected formats
- Align final-template threshold (max_steps - 2) with budget switch threshold
- Eliminate hasattr(self, "_state") in _reset_state by initializing in __init__
* fix(long-task): honor final signal and file tracking
Co-authored-by: Cursor <cursoragent@cursor.com>
* feat(long-task): improve prompt structure and agent contract
- Expand LongTaskTool.description to instruct parent agent on goal
  construction, return value semantics, and how to handle results.
- Expand CompleteTool.description to emphasize that the summary IS the
  final answer returned to the parent agent.
- Prefix validated return value with an explicit "final answer" directive
  to stop parent agent from re-running work.
- Redesign step_start.md: Step 1 is now explicitly for exploration,
  planning, and skeleton-building. complete() is discouraged.
- Remove bulky payload debug logging from _emit(); add targeted
  info/warning/error logs at key state transitions instead.
- Add signal_type to HandoffState for cleaner signal detection.
* test(long-task): expect wrapped completion message after validation
Align assertions with LongTaskTool final return shape on main.
Co-authored-by: Cursor <cursoragent@cursor.com>
* feat(webui): turn timing strip, latency, and session-switch restore
- Agent loop: publish goal_status run/idle for WebSocket turns; attach
  wall-clock latency_ms on turn_end and persisted assistant metadata.
- WebSocket channel: forward goal_status and latency fields to clients.
- NanobotClient: track goal_status started_at per chat without requiring
  onChat; useNanobotStream restores run strip when returning to a chat.
- Thread UI: composer/shell viewport hooks for run duration and latency;
  format helpers and i18n strings.
- MessageBubble: drop trailing StreamCursor (layout artifact vs block markdown).
- Builtin / tests: model command coverage, websocket and loop tests.
Covers multi-session UX and round-trip timing visibility for the WebUI.
Co-authored-by: Cursor <cursoragent@cursor.com>
* fix: keep message-tool file attachments after canonical history hydrate
- MessageTool records per-turn media paths delivered to the active chat.
- nanobot.utils.session_attachments stages out-of-media-root files and
  merges into the last assistant message before save (loop stays a thin call).
- WebUI MediaCell: use a signed URL as a real download link when present.
Fixes attachments flashing then vanishing on turn_end when paths lived
outside get_media_dir (e.g. workspace files).
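Merging per-turn media into the last assistant message before save can be sketched as follows. This is a hypothetical stand-in, not the real helper:

- `merge_media_into_last_assistant` and the `"media"` key are assumptions.
- The actual helper is `merge_turn_media_into_last_assistant` in `nanobot.utils.session_attachments`, and its signature is not shown here.
- Staging files from outside the media root is omitted.

```python
from typing import Any


def merge_media_into_last_assistant(
    messages: list[dict[str, Any]], turn_media: list[str]
) -> bool:
    """Attach paths sent by the message tool to the final assistant message.

    Returns True if a message was updated. Existing paths are kept and
    duplicates are skipped, so merging the same paths twice is harmless.
    """
    if not turn_media:
        return False
    for message in reversed(messages):
        if message.get("role") == "assistant":
            media = message.setdefault("media", [])
            for path in turn_media:
                if path not in media:
                    media.append(path)
            return True
    return False  # no assistant message yet: nothing to attach to
```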
Co-authored-by: Cursor <cursoragent@cursor.com>
* feat(webui): agent activity cluster, stable keys, LTR sheen labels
- Group reasoning and tool traces in AgentActivityCluster with i18n summaries
- Stabilize React list keys for activity clusters (first message id anchor)
- Replace background-clip shimmer with overlay sheen for streaming labels
- ThreadMessages/MessageList integration and locale strings
Co-authored-by: Cursor <cursoragent@cursor.com>
* fix(webui): render assistant reasoning with Markdown + deferred stream
- Use MarkdownText for ReasoningBubble body (same GFM/KaTeX path as replies)
- Apply muted/italic prose tokens so thinking stays visually subordinate
- useDeferredValue while reasoningStreaming to ease parser work during deltas
- Preload markdown chunk when trace opens; add regression test with preloaded renderer
Co-authored-by: Cursor <cursoragent@cursor.com>
* fix(webui): default-collapse agent activity cluster while Working
Outer fold no longer auto-expands during isTurnStreaming; user opens to see traces.
Header sheen and live summary unchanged.
Co-authored-by: Cursor <cursoragent@cursor.com>
* feat(long_task): cumulative run history, file union, and prompt tuning
Inject cross-step summaries and merged file paths into middle/final step
templates so chains do not lose early context. Strip the last run-history
block when it duplicates Previous Progress to save tokens. Add optional
cumulative_prompt_max_chars and cumulative_step_body_max_chars parameters
with clamped defaults.
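A rough sketch of the cumulative history block. It drops the newest summary when it duplicates Previous Progress, truncates each step body, and caps the total. Only the parameter names `cumulative_prompt_max_chars` and `cumulative_step_body_max_chars` come from the commit. The default values, clamp bounds, and `[step N]` block format are assumptions.

```python
from __future__ import annotations


def clamp(value: int | None, default: int, lo: int, hi: int) -> int:
    """Fall back to `default` when unset, then clamp into [lo, hi]."""
    return max(lo, min(hi, default if value is None else value))


def build_run_history(
    step_summaries: list[str],
    previous_progress: str,
    cumulative_prompt_max_chars: int | None = None,
    cumulative_step_body_max_chars: int | None = None,
) -> str:
    # Hypothetical defaults and bounds, for illustration only.
    total_cap = clamp(cumulative_prompt_max_chars, 6000, 500, 50_000)
    body_cap = clamp(cumulative_step_body_max_chars, 800, 100, 10_000)

    summaries = list(step_summaries)
    # The newest summary duplicates "Previous Progress": drop it to save tokens.
    if summaries and summaries[-1].strip() == previous_progress.strip():
        summaries.pop()

    blocks = []
    for i, body in enumerate(summaries, start=1):
        if len(body) > body_cap:
            body = body[:body_cap] + "..."
        blocks.append(f"[step {i}] {body}")
    history = "\n".join(blocks)
    # Keep the most recent text if the whole history is over budget.
    return history[-total_cap:] if len(history) > total_cap else history


print(build_run_history(["explored repo", "added tests"], "added tests"))
```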
Co-authored-by: Cursor <cursoragent@cursor.com>
* fix(webui): session switch keeps in-flight thread and replays buffered WS
Save the prior chat message list to the per-chat cache in a layout effect
when chatId changes (before stale writes could corrupt another chat).
Skip one post-switch layout cache tick so we do not snapshot the wrong tab.
Buffer inbound events per chat_id when no onChat subscriber is registered
(e.g. user focused another session) and drain on resubscribe up to a cap,
so streaming deltas are not lost while off-tab.
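The per-chat buffering rule can be expressed in Python for illustration; the actual NanobotClient is TypeScript. The sketch makes these assumptions:

- `ChatEventBuffer` and its methods are hypothetical.
- "Up to a cap" is read as a bounded buffer that discards the oldest events on overflow.
- The default cap of 500 is an arbitrary placeholder.

```python
from collections import defaultdict, deque
from typing import Any, Callable

Handler = Callable[[dict[str, Any]], None]


class ChatEventBuffer:
    """Hypothetical per-chat_id router that buffers while nobody is subscribed."""

    def __init__(self, cap: int = 500) -> None:
        self._handlers: dict[str, Handler] = {}
        # deque(maxlen=cap) discards the oldest event once the cap is reached.
        self._buffers: defaultdict[str, deque] = defaultdict(lambda: deque(maxlen=cap))

    def dispatch(self, event: dict[str, Any]) -> None:
        chat_id = event["chat_id"]
        handler = self._handlers.get(chat_id)
        if handler is None:
            self._buffers[chat_id].append(event)  # user is on another session
        else:
            handler(event)

    def subscribe(self, chat_id: str, handler: Handler) -> None:
        self._handlers[chat_id] = handler
        # Replay, in arrival order, what came in while the chat was unfocused.
        for event in self._buffers.pop(chat_id, deque()):
            handler(event)

    def unsubscribe(self, chat_id: str) -> None:
        self._handlers.pop(chat_id, None)
```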
Co-authored-by: Cursor <cursoragent@cursor.com>
* fix(webui): snap thread scroll to bottom on session open (no smooth glide)
Use scroll-behavior auto on the viewport, instant programmatic scroll when
following new messages and on scrollToBottomSignal. Keep smooth only for
the explicit scroll-to-bottom button.
Co-authored-by: Cursor <cursoragent@cursor.com>
* fix(webui): respect manual scroll-up after opening a session
Track when the user leaves the bottom with a ref and skip ResizeObserver
and deferred bottom snaps until they return or the conversation is reset.
Remove the time-based force-bottom window that overrode atBottom.
Multi-frame scrollToBottom honours the same guard unless force (scroll button).
Co-authored-by: Cursor <cursoragent@cursor.com>
* Publish long_task UI snapshots on outbound metadata
- Add OUTBOUND_META_AGENT_UI (_agent_ui) for channel-agnostic structured state
- LongTaskTool publishes {kind: long_task, data: snapshot} on the bus with _progress
- WebSocket send forwards metadata as agent_ui for WebUI clients
- Tests for bus payload, WS frame, and progress assertions
- Fix loop progress tests: ignore _goal_status in streaming final filter and
  avoid brittle outbound[-1] ordering after goal status idle messages
Co-authored-by: Cursor <cursoragent@cursor.com>
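A sketch of the payload path in this commit. The tool attaches a structured snapshot under `_agent_ui` in outbound metadata, and the WebSocket channel re-emits it as `agent_ui`. The key name `_agent_ui` and the `{kind, data}` shape come from the commit. Three details are assumptions:

- `_progress` is treated as a boolean flag.
- The frame fields `event` and `chat_id` are guesses.
- Both helper functions are hypothetical.

A later commit in this series removes these tool-driven agent_ui payloads.

```python
from __future__ import annotations

import json
from typing import Any

OUTBOUND_META_AGENT_UI = "_agent_ui"  # key name from the commit message


def long_task_metadata(snapshot: dict[str, Any]) -> dict[str, Any]:
    """Channel-agnostic metadata a tool could publish with a progress message."""
    return {
        "_progress": True,  # assumed flag value
        OUTBOUND_META_AGENT_UI: {"kind": "long_task", "data": snapshot},
    }


def websocket_frame(chat_id: str, metadata: dict[str, Any]) -> str | None:
    """Hypothetical WS encoding: forward the structured state as `agent_ui`."""
    ui = metadata.get(OUTBOUND_META_AGENT_UI)
    if ui is None:
        return None  # nothing structured to forward
    return json.dumps({"event": "agent_ui", "chat_id": chat_id, "agent_ui": ui})
```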
* feat: WebUI long_task activity card and resilient history merge
Add optional ui_summary to the long_task tool for one-line UI labels. Stream
long_task agent_ui into a dedicated message row with timeline, markdown peek,
and a right sheet for details. Merge canonical history after turn_end while
re-inserting long_task rows before the final assistant reply. Collapse
duplicate task_start/step_start steps in the timeline and extend i18n.
Co-authored-by: Cursor <cursoragent@cursor.com>
* refactor: align long_task with thread_goal and drop orchestrator UI
- Persist sustained objectives via session metadata (long_task / complete_goal); no subagent wiring or tool-driven agent_ui payloads.
- Remove WebUI long-task activity UI, types, and translations; history merge preserves trace replay only, with legacy long_task rows normalized to traces.
- Drop long_task prompt templates and get_long_task_run_dir; add webui thread disk helper for gateway persistence tests.
Co-authored-by: Cursor <cursoragent@cursor.com>
* feat(agent): thread goal runtime context, tools, and skill
- Add thread_goal_state helper and mirror active objectives into Runtime Context
- Wire loop/context/memory/events as needed for goal metadata in turns
- Expand long_task / complete_goal semantics (pivot/cancel/honest recap)
- Add always-on thread-goal SKILL.md; align /goal command prompt
- Tests for context builder and thread goal state
- Remove unused webui ChatPane component
Co-authored-by: Cursor <cursoragent@cursor.com>
* feat(thread-goal): add websocket snapshot helper and publish goal updates from long_task
Introduce thread_goal_ws_blob for bounded JSON snapshots, attach snapshots to
websocket turn_end metadata in AgentLoop, and let long_task fan-out dedicated
thread_goal frames on the websocket channel after persisting session metadata.
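One way to keep a goal snapshot bounded before attaching it to `turn_end` metadata is to halve the longest string field until the JSON fits a byte budget. The function below is a hypothetical stand-in for `thread_goal_ws_blob` (later `goal_state_ws_blob`). The 2 KiB default and the truncation strategy are assumptions.

```python
from __future__ import annotations

import json
from typing import Any


def bounded_ws_blob(goal: dict[str, Any], max_bytes: int = 2048) -> dict[str, Any] | None:
    """Return a copy of `goal` whose UTF-8 JSON encoding fits in `max_bytes`."""
    blob = dict(goal)

    def size() -> int:
        return len(json.dumps(blob, ensure_ascii=False).encode("utf-8"))

    while size() > max_bytes:
        # Halve the longest string field and mark the cut with an ellipsis.
        strings = [k for k, v in blob.items() if isinstance(v, str) and len(v) > 16]
        if not strings:
            return None  # cannot shrink further; caller skips the snapshot
        key = max(strings, key=lambda k: len(blob[k]))
        blob[key] = blob[key][: len(blob[key]) // 2] + "…"
    return blob
```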
Co-authored-by: Cursor <cursoragent@cursor.com>
* feat(channels): websocket thread_goal frames, turn_end replay, and session API scrub for subagent inject
Emit thread_goal events and optional thread_goal on turn_end; scrub persisted
subagent announce blobs on GET /api/sessions/.../messages and shorten session
list previews so WebUI does not surface full Task/Summarize scaffolding.
Co-authored-by: Cursor <cursoragent@cursor.com>
* feat(webui): merge ephemeral traces per user turn when reconciling canonical history
Preserve disk/live trace rows inside the matching user–assistant segment instead
of stacking every trace before the final assistant reply (fixes inflated tool
counts after refresh or session switch).
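Keeping trace rows inside their own user turn amounts to segmenting the history at each user message. Each turn's traces are then placed before that turn's assistant reply. The sketch makes these assumptions:

- Traces are keyed by the zero-based index of the user turn they belong to.
- Rows are plain dicts with a `role`.
- `merge_traces_per_turn` is a hypothetical name.

```python
from typing import Any

Row = dict[str, Any]


def merge_traces_per_turn(history: list[Row], traces: dict[int, list[Row]]) -> list[Row]:
    merged: list[Row] = []
    turn = -1                # index of the current user turn
    pending: list[Row] = []  # traces not yet placed for this turn

    def flush() -> None:
        merged.extend(pending)
        pending.clear()

    for row in history:
        if row["role"] == "user":
            flush()  # leftover traces stay inside the turn they came from
            turn += 1
            merged.append(row)
            pending.extend(traces.get(turn, []))
        elif row["role"] == "assistant":
            flush()  # activity precedes the assistant reply of the same turn
            merged.append(row)
        else:
            merged.append(row)
    flush()  # an in-flight turn keeps its traces at the end
    return merged
```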
Co-authored-by: Cursor <cursoragent@cursor.com>
* feat(webui): show assistant reply copy only on the last slice before the next user turn
Avoid duplicate copy affordances on intermediate assistant bubbles that precede
more agent activity in the same turn (tools or further assistant text).
Co-authored-by: Cursor <cursoragent@cursor.com>
* feat(webui): thread_goal stream plumbing, composer goal strip, sky glow, and client-side subagent scrub projection
Track thread_goal and turn_goal snapshots in NanobotClient, hydrate React state
from thread_goal frames and turn_end, surface objective/elapsed in the composer,
add breathing sky halo CSS while goals are active, mirror server scrub logic on
history hydration and webui_thread snapshots, and extend tests/client mocks.
Co-authored-by: Cursor <cursoragent@cursor.com>
* feat(channels): add Slack Socket Mode connect timeout with actionable timeout errors
Abort hung websockets.connect handshakes after a bounded wait, log REST-vs-WSS
guidance, surface RuntimeError to channel startup, and log successful WSS setup.
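The bounded handshake wait can be sketched with `asyncio.wait_for`. The sketch makes these assumptions:

- `connect_with_timeout`, the 15-second default, and the error text are hypothetical.
- In the real channel, the awaited coroutine would be the `websockets.connect(...)` handshake rather than the stand-in used here.

```python
import asyncio
from typing import Awaitable, TypeVar

T = TypeVar("T")


async def connect_with_timeout(handshake: Awaitable[T], timeout_s: float = 15.0) -> T:
    """Await a WebSocket handshake, turning a hang into an actionable error."""
    try:
        return await asyncio.wait_for(handshake, timeout=timeout_s)
    except asyncio.TimeoutError as exc:
        # Surfaced to channel startup instead of hanging forever.
        raise RuntimeError(
            f"Slack Socket Mode WSS handshake timed out after {timeout_s:.0f}s; "
            "if REST calls succeed, check proxies or firewalls that block WebSockets"
        ) from exc


async def hung_handshake() -> None:
    await asyncio.sleep(3600)  # stand-in for a connect that never completes
```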
Co-authored-by: Cursor <cursoragent@cursor.com>
* webui: expand thread goal in composer bottom sheet
Add ChevronUp control on the run/goal strip that opens a bottom Sheet
with full ui_summary and objective. Inline preview logic in RunElapsedStrip,
add i18n strings across locales, and a composer unit test.
Co-authored-by: Cursor <cursoragent@cursor.com>
* fix(webui): widen dedupeToolCallsForUi input for session API typing
fetchSessionMessages types tool_calls as unknown; accept unknown so tsc
build passes when passing message.tool_calls through.
Co-authored-by: Cursor <cursoragent@cursor.com>
* refactor(agent): extract WebSocket turn run status to webui_turn_helpers
* refactor(skills): rename thread-goal to long-task and document idempotent goals
* feat(skills): rename sustained-goal skill to long-goal and tighten long_task guidance
* chore: remove unused subagent/context/router helpers
* feat(session): rename sustained goal to goal_state and align WS/WebUI
- Move helpers from agent/thread_goal_state to session/goal_state:
  GOAL_STATE_KEY, goal_state_runtime_lines, goal_state_ws_blob, parse_goal_state.
- Session metadata now uses "goal_state"; still read legacy "thread_goal";
  long_task writes drop the legacy key after save.
- WebSocket: event/field goal_state, _goal_state_sync; turn_end carries goal_state;
  accept legacy _thread_goal_sync/thread_goal inbound metadata for dispatch.
- WebUI: GoalStateWsPayload, goalState hook/client props, i18n keys goalState*.
- Runtime Context copy uses "Goal (active):" instead of "Thread goal".
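The metadata migration rule above can be sketched on a plain dict. Reads prefer `goal_state` and fall back to legacy `thread_goal`; writes store only the new key and drop the legacy one. `GOAL_STATE_KEY` and its value match the commit. `LEGACY_GOAL_KEY`, `read_goal_state`, and `write_goal_state` are hypothetical helpers, not the `parse_goal_state` API.

```python
from __future__ import annotations

from typing import Any

GOAL_STATE_KEY = "goal_state"
LEGACY_GOAL_KEY = "thread_goal"  # hypothetical constant name for the old key


def read_goal_state(metadata: dict[str, Any]) -> dict[str, Any] | None:
    """Prefer the new key; fall back to sessions written before the rename."""
    value = metadata.get(GOAL_STATE_KEY)
    if value is None:
        value = metadata.get(LEGACY_GOAL_KEY)
    return value if isinstance(value, dict) else None


def write_goal_state(metadata: dict[str, Any], goal: dict[str, Any] | None) -> None:
    """Store under the new key only, removing the legacy key."""
    if goal is None:
        metadata.pop(GOAL_STATE_KEY, None)
    else:
        metadata[GOAL_STATE_KEY] = goal
    metadata.pop(LEGACY_GOAL_KEY, None)
```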
* feat(agent): stream Anthropic thinking deltas and fix stream idle timeout
* refactor(webui): transcript jsonl as sole timeline source
* fix(agent): reject mismatched WS message chat_id and stream reasoning deltas
* feat(webui): hydrate sustained goal and run timer after websocket subscribe
* chore(webui,websocket): remove unused fetch helpers and legacy thread_goal WS paths
* Raise default max_tokens and context window in agent schema.
Align AgentDefaults and ModelPresetConfig with typical Claude-scale usage
(32k completion budget, 256k context window) and update migration tests.
Co-authored-by: Cursor <cursoragent@cursor.com>
* feat(gateway): bootstrap prefers in-memory model; clarify websocket naming
* fix(websocket): websocket _handle_message passes is_dm; refresh /status test expectations
---------
Co-authored-by: chengyongru <2755839590@qq.com>
Co-authored-by: chengyongru <chengyongru.ai@gmail.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
1660 lines
69 KiB
Python
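The `AgentLoop` in this file advances each turn through `TurnState` using its `_TRANSITIONS` table. A handler runs for the current state and returns an event string, and the driver looks up `(state, event)` to find the next state. A stripped-down driver in that style might look like this. The `drive` function, the constant handlers, and the error on an unknown pair are assumptions for illustration, not nanobot's actual driver. Reading `"shortcut"` as "the command router answered the message directly" is also an assumption.

```python
from enum import Enum, auto
from typing import Callable


class TurnState(Enum):
    RESTORE = auto()
    COMPACT = auto()
    COMMAND = auto()
    BUILD = auto()
    RUN = auto()
    SAVE = auto()
    RESPOND = auto()
    DONE = auto()


TRANSITIONS: dict[tuple[TurnState, str], TurnState] = {
    (TurnState.RESTORE, "ok"): TurnState.COMPACT,
    (TurnState.COMPACT, "ok"): TurnState.COMMAND,
    (TurnState.COMMAND, "dispatch"): TurnState.BUILD,
    (TurnState.COMMAND, "shortcut"): TurnState.DONE,
    (TurnState.BUILD, "ok"): TurnState.RUN,
    (TurnState.RUN, "ok"): TurnState.SAVE,
    (TurnState.SAVE, "ok"): TurnState.RESPOND,
    (TurnState.RESPOND, "ok"): TurnState.DONE,
}


def drive(handlers: dict[TurnState, Callable[[], str]]) -> list[TurnState]:
    """Run handlers until DONE and return the visited states."""
    state = TurnState.RESTORE
    visited = [state]
    while state is not TurnState.DONE:
        event = handlers[state]()
        try:
            state = TRANSITIONS[(state, event)]
        except KeyError:
            raise RuntimeError(f"no transition for {state.name} on {event!r}") from None
        visited.append(state)
    return visited


# Assumed example: COMMAND handled the message itself, so it returns "shortcut".
handlers = {s: (lambda: "ok") for s in TurnState if s is not TurnState.DONE}
handlers[TurnState.COMMAND] = lambda: "shortcut"
print([s.name for s in drive(handlers)])
```

A table keeps every legal path in one place. A handler that returns an event the table does not list fails loudly instead of silently skipping a phase.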
"""Agent loop: the core processing engine."""

from __future__ import annotations

import asyncio
import dataclasses
import os
import time
from contextlib import AsyncExitStack, nullcontext, suppress
from dataclasses import dataclass, field
from enum import Enum, auto
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable

from loguru import logger

from nanobot.agent import model_presets as preset_helpers
from nanobot.agent.autocompact import AutoCompact
from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, CompositeHook
from nanobot.agent.memory import Consolidator, Dream
from nanobot.agent.progress_hook import AgentProgressHook
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.file_state import FileStateStore, bind_file_states, reset_file_states
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.self import MyTool
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
from nanobot.config.schema import AgentDefaults, ModelPresetConfig
from nanobot.providers.base import LLMProvider
from nanobot.providers.factory import ProviderSnapshot
from nanobot.session.goal_state import goal_state_runtime_lines, goal_state_ws_blob
from nanobot.session.manager import Session, SessionManager
from nanobot.utils.artifacts import generated_image_paths_from_messages
from nanobot.utils.document import extract_documents
from nanobot.utils.helpers import image_placeholder_text
from nanobot.utils.helpers import truncate_text as truncate_text_fn
from nanobot.utils.image_generation_intent import image_generation_prompt
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
from nanobot.utils.session_attachments import merge_turn_media_into_last_assistant
from nanobot.utils.webui_titles import mark_webui_session, maybe_generate_webui_title_after_turn
from nanobot.utils.webui_turn_helpers import publish_turn_run_status

if TYPE_CHECKING:
    from nanobot.config.schema import (
        ChannelsConfig,
        ProviderConfig,
        ToolsConfig,
    )
    from nanobot.cron.service import CronService


UNIFIED_SESSION_KEY = "unified:default"


class TurnState(Enum):
    RESTORE = auto()
    COMPACT = auto()
    COMMAND = auto()
    BUILD = auto()
    RUN = auto()
    SAVE = auto()
    RESPOND = auto()
    DONE = auto()


@dataclass
class StateTraceEntry:
    state: TurnState
    started_at: float
    duration_ms: float
    event: str
    error: str | None = None


@dataclass
class TurnContext:
    msg: InboundMessage
    session_key: str
    state: TurnState
    turn_id: str
    session: Session | None = None

    history: list[dict[str, Any]] = field(default_factory=list)
    initial_messages: list[dict[str, Any]] = field(default_factory=list)

    final_content: str | None = None
    tools_used: list[str] = field(default_factory=list)
    all_messages: list[dict[str, Any]] = field(default_factory=list)
    stop_reason: str = ""
    had_injections: bool = False

    user_persisted_early: bool = False
    save_skip: int = 0

    outbound: OutboundMessage | None = None
    generated_media: list[str] = field(default_factory=list)

    on_progress: Callable[..., Awaitable[None]] | None = None
    on_stream: Callable[[str], Awaitable[None]] | None = None
    on_stream_end: Callable[..., Awaitable[None]] | None = None
    on_retry_wait: Callable[[str], Awaitable[None]] | None = None

    pending_queue: asyncio.Queue | None = None
    pending_summary: str | None = None

    turn_wall_started_at: float = field(default_factory=time.time)
    turn_latency_ms: int | None = None

    trace: list[StateTraceEntry] = field(default_factory=list)


class AgentLoop:
    """
    The agent loop is the core processing engine.

    It:
    1. Receives messages from the bus
    2. Builds context with history, memory, skills
    3. Calls the LLM
    4. Executes tool calls
    5. Sends responses back
    """

    @property
    def current_iteration(self) -> int:
        return self._current_iteration

    @property
    def tool_names(self) -> list[str]:
        return self.tools.tool_names

    _RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
    _PENDING_USER_TURN_KEY = "pending_user_turn"

    # Event-driven state transition table.
    # Handlers return an event string; the driver looks up the next state here.
    _TRANSITIONS: dict[tuple[TurnState, str], TurnState] = {
        (TurnState.RESTORE, "ok"): TurnState.COMPACT,
        (TurnState.COMPACT, "ok"): TurnState.COMMAND,
        (TurnState.COMMAND, "dispatch"): TurnState.BUILD,
        (TurnState.COMMAND, "shortcut"): TurnState.DONE,
        (TurnState.BUILD, "ok"): TurnState.RUN,
        (TurnState.RUN, "ok"): TurnState.SAVE,
        (TurnState.SAVE, "ok"): TurnState.RESPOND,
        (TurnState.RESPOND, "ok"): TurnState.DONE,
    }

    def __init__(
        self,
        bus: MessageBus,
        provider: LLMProvider,
        workspace: Path,
        model: str | None = None,
        max_iterations: int | None = None,
        context_window_tokens: int | None = None,
        context_block_limit: int | None = None,
        max_tool_result_chars: int | None = None,
        provider_retry_mode: str = "standard",
        tool_hint_max_length: int | None = None,
        cron_service: CronService | None = None,
        restrict_to_workspace: bool = False,
        session_manager: SessionManager | None = None,
        mcp_servers: dict | None = None,
        channels_config: ChannelsConfig | None = None,
        timezone: str | None = None,
        session_ttl_minutes: int = 0,
        consolidation_ratio: float = 0.5,
        max_messages: int = 120,
        hooks: list[AgentHook] | None = None,
        unified_session: bool = False,
        disabled_skills: list[str] | None = None,
        tools_config: ToolsConfig | None = None,
        image_generation_provider_config: ProviderConfig | None = None,
        image_generation_provider_configs: dict[str, ProviderConfig] | None = None,
        provider_snapshot_loader: Callable[..., ProviderSnapshot] | None = None,
        provider_signature: tuple[object, ...] | None = None,
        model_presets: dict[str, ModelPresetConfig] | None = None,
        model_preset: str | None = None,
        preset_snapshot_loader: preset_helpers.PresetSnapshotLoader | None = None,
        runtime_model_publisher: Callable[[str, str | None], None] | None = None,
    ):
        from nanobot.config.schema import ToolsConfig

        _tc = tools_config or ToolsConfig()
        defaults = AgentDefaults()
        self.bus = bus
        self.channels_config = channels_config
        self.provider = provider
        self._provider_snapshot_loader = provider_snapshot_loader
        self._preset_snapshot_loader = preset_snapshot_loader
        self._runtime_model_publisher = runtime_model_publisher
        self._provider_signature = provider_signature
        self._default_selection_signature = preset_helpers.default_selection_signature(provider_signature)
        self.workspace = workspace
        self.model = model or provider.get_default_model()
        self.max_iterations = (
            max_iterations if max_iterations is not None else defaults.max_tool_iterations
        )
        self.context_window_tokens = (
            context_window_tokens
            if context_window_tokens is not None
            else defaults.context_window_tokens
        )
        self.context_block_limit = context_block_limit
        self.max_tool_result_chars = (
            max_tool_result_chars
            if max_tool_result_chars is not None
            else defaults.max_tool_result_chars
        )
        self.provider_retry_mode = provider_retry_mode
        self.tool_hint_max_length = (
            tool_hint_max_length if tool_hint_max_length is not None
            else defaults.tool_hint_max_length
        )
        self.tools_config = _tc
        self.web_config = _tc.web
        self.exec_config = _tc.exec
        self._image_generation_provider_configs = dict(image_generation_provider_configs or {})
        if (
            image_generation_provider_config is not None
            and "openrouter" not in self._image_generation_provider_configs
        ):
            self._image_generation_provider_configs["openrouter"] = image_generation_provider_config
        self.cron_service = cron_service
        self.restrict_to_workspace = restrict_to_workspace
        self._start_time = time.time()
        self._last_usage: dict[str, int] = {}
        self._pending_turn_latency_ms: dict[str, int] = {}
        self._extra_hooks: list[AgentHook] = hooks or []

        self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills)
        self.sessions = session_manager or SessionManager(workspace)
        self.tools = ToolRegistry()
        # One file-read/write tracker per logical session. The tool registry is
        # shared by this loop, so tools resolve the active state via contextvars.
        self._file_state_store = FileStateStore()
        self.runner = AgentRunner(provider)
        self.subagents = SubagentManager(
            provider=provider,
            workspace=workspace,
            bus=bus,
            model=self.model,
            tools_config=_tc,
            max_tool_result_chars=self.max_tool_result_chars,
            restrict_to_workspace=restrict_to_workspace,
            disabled_skills=disabled_skills,
            max_iterations=self.max_iterations,
        )
        self._unified_session = unified_session
        self._max_messages = max_messages if max_messages > 0 else 120
        self._running = False
        self._mcp_servers = mcp_servers or {}
        self._mcp_stacks: dict[str, AsyncExitStack] = {}
        self._mcp_connected = False
        self._mcp_connecting = False
        self._active_tasks: dict[str, list[asyncio.Task]] = {}  # session_key -> tasks
        self._background_tasks: list[asyncio.Task] = []
        self._session_locks: dict[str, asyncio.Lock] = {}
        # Per-session pending queues for mid-turn message injection.
        # When a session has an active task, new messages for that session
        # are routed here instead of creating a new task.
        self._pending_queues: dict[str, asyncio.Queue] = {}
        # NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3.
        _max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3"))
        self._concurrency_gate: asyncio.Semaphore | None = (
            asyncio.Semaphore(_max) if _max > 0 else None
        )
        self.consolidator = Consolidator(
            store=self.context.memory,
            provider=provider,
            model=self.model,
            sessions=self.sessions,
            context_window_tokens=self.context_window_tokens,
            build_messages=self.context.build_messages,
            get_tool_definitions=self.tools.get_definitions,
            max_completion_tokens=provider.generation.max_tokens,
            consolidation_ratio=consolidation_ratio,
        )
        self.auto_compact = AutoCompact(
            sessions=self.sessions,
            consolidator=self.consolidator,
            session_ttl_minutes=session_ttl_minutes,
        )
        self.dream = Dream(
            store=self.context.memory,
            provider=provider,
            model=self.model,
        )
        self.model_presets: dict[str, ModelPresetConfig] = model_presets or {}
        self._active_preset: str | None = None
        if model_preset:
            self.set_model_preset(model_preset, publish_update=False)
        self._register_default_tools()
        self._runtime_vars: dict[str, Any] = {}
        self._current_iteration: int = 0
        self.commands = CommandRouter()
        register_builtin_commands(self.commands)

    @classmethod
    def from_config(
        cls,
        config: Any,
        bus: MessageBus | None = None,
        **extra: Any,
    ) -> AgentLoop:
        """Create an AgentLoop from config with the common parameter set.

        Extra keyword arguments are forwarded to ``AgentLoop.__init__``,
        allowing callers to override or extend the standard config-derived
        parameters (e.g. ``cron_service``, ``session_manager``).
        """
        from nanobot.providers.factory import make_provider

        if bus is None:
            bus = MessageBus()
        defaults = config.agents.defaults
        provider = extra.pop("provider", None) or make_provider(config)
        resolved = config.resolve_preset()
        model = extra.pop("model", None) or resolved.model
        context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens
        provider_snapshot_loader = extra.pop("provider_snapshot_loader", None)
        preset_snapshot_loader = extra.pop("preset_snapshot_loader", None) or preset_helpers.make_preset_snapshot_loader(
            config,
            provider_snapshot_loader,
        )
return cls(
|
|
bus=bus,
|
|
provider=provider,
|
|
workspace=config.workspace_path,
|
|
model=model,
|
|
max_iterations=defaults.max_tool_iterations,
|
|
context_window_tokens=context_window_tokens,
|
|
context_block_limit=defaults.context_block_limit,
|
|
max_tool_result_chars=defaults.max_tool_result_chars,
|
|
provider_retry_mode=defaults.provider_retry_mode,
|
|
tool_hint_max_length=defaults.tool_hint_max_length,
|
|
restrict_to_workspace=config.tools.restrict_to_workspace,
|
|
mcp_servers=config.tools.mcp_servers,
|
|
channels_config=config.channels,
|
|
timezone=defaults.timezone,
|
|
unified_session=defaults.unified_session,
|
|
disabled_skills=defaults.disabled_skills,
|
|
session_ttl_minutes=defaults.session_ttl_minutes,
|
|
consolidation_ratio=defaults.consolidation_ratio,
|
|
max_messages=defaults.max_messages,
|
|
tools_config=config.tools,
|
|
model_presets=preset_helpers.configured_model_presets(config),
|
|
model_preset=defaults.model_preset,
|
|
provider_snapshot_loader=provider_snapshot_loader,
|
|
preset_snapshot_loader=preset_snapshot_loader,
|
|
**extra,
|
|
)
|
|
|
|
def _sync_subagent_runtime_limits(self) -> None:
|
|
"""Keep subagent runtime limits aligned with mutable loop settings."""
|
|
self.subagents.max_iterations = self.max_iterations
|
|
|
|
def _apply_provider_snapshot(
|
|
self,
|
|
snapshot: ProviderSnapshot,
|
|
*,
|
|
publish_update: bool = True,
|
|
model_preset: str | None = None,
|
|
) -> None:
|
|
"""Swap model/provider for future turns without disturbing an active one."""
|
|
provider = snapshot.provider
|
|
model = snapshot.model
|
|
context_window_tokens = snapshot.context_window_tokens
|
|
old_model = self.model
|
|
self.provider = provider
|
|
self.model = model
|
|
self.context_window_tokens = context_window_tokens
|
|
self.runner.provider = provider
|
|
self.subagents.set_provider(provider, model)
|
|
self.consolidator.set_provider(provider, model, context_window_tokens)
|
|
self.dream.set_provider(provider, model)
|
|
self._provider_signature = snapshot.signature
|
|
if publish_update and self._runtime_model_publisher is not None:
|
|
self._runtime_model_publisher(
|
|
self.model,
|
|
model_preset if model_preset is not None else self.model_preset,
|
|
)
|
|
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
|
|
|
|
    def _refresh_provider_snapshot(self) -> None:
        if self._provider_snapshot_loader is None:
            return
        try:
            snapshot = self._provider_snapshot_loader()
        except Exception:
            logger.exception("Failed to refresh provider config")
            return
        default_selection = preset_helpers.default_selection_signature(snapshot.signature)
        if self._active_preset and self._default_selection_signature in (None, default_selection):
            self._default_selection_signature = default_selection
            try:
                snapshot = self._build_model_preset_snapshot(self._active_preset)
            except Exception:
                logger.exception("Failed to refresh active model preset")
                return
        else:
            self._active_preset = None
            self._default_selection_signature = default_selection
        if snapshot.signature == self._provider_signature:
            return
        self._default_selection_signature = preset_helpers.default_selection_signature(snapshot.signature)
        self._apply_provider_snapshot(snapshot)

    @property
    def model_preset(self) -> str | None:
        return self._active_preset

    @model_preset.setter
    def model_preset(self, name: str | None) -> None:
        self.set_model_preset(name)

    def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot:
        return preset_helpers.build_runtime_preset_snapshot(
            name=name,
            presets=self.model_presets,
            provider=self.provider,
            loader=self._preset_snapshot_loader,
        )

    def set_model_preset(self, name: str | None, *, publish_update: bool = True) -> None:
        """Resolve a preset by name and apply all runtime model dependents."""
        name = preset_helpers.normalize_preset_name(name, self.model_presets)
        snapshot = self._build_model_preset_snapshot(name)
        self._apply_provider_snapshot(snapshot, publish_update=publish_update, model_preset=name)
        self._active_preset = name

    def _register_default_tools(self) -> None:
        """Register the default set of tools via the plugin loader."""
        from nanobot.agent.tools.context import ToolContext
        from nanobot.agent.tools.loader import ToolLoader

        ctx = ToolContext(
            config=self.tools_config,
            workspace=str(self.workspace),
            bus=self.bus,
            subagent_manager=self.subagents,
            cron_service=self.cron_service,
            sessions=self.sessions,
            provider_snapshot_loader=self._provider_snapshot_loader,
            image_generation_provider_configs=self._image_generation_provider_configs,
            timezone=self.context.timezone or "UTC",
        )
        loader = ToolLoader()
        registered = loader.load(ctx, self.tools)

        # MyTool needs a reference to the runtime state, so it is registered manually.
        if self.tools_config.my.enable:
            self.tools.register(
                MyTool(runtime_state=self, modify_allowed=self.tools_config.my.allow_set)
            )
            registered.append("my")

        logger.info("Registered {} tools: {}", len(registered), registered)

    async def _connect_mcp(self) -> None:
        """Connect to configured MCP servers (one-time, lazy)."""
        if self._mcp_connected or self._mcp_connecting or not self._mcp_servers:
            return
        self._mcp_connecting = True
        from nanobot.agent.tools.mcp import connect_mcp_servers

        try:
            self._mcp_stacks = await connect_mcp_servers(self._mcp_servers, self.tools)
            if self._mcp_stacks:
                self._mcp_connected = True
            else:
                logger.warning("No MCP servers connected successfully (will retry next message)")
        except asyncio.CancelledError:
            logger.warning("MCP connection cancelled (will retry next message)")
            self._mcp_stacks.clear()
        except BaseException as e:
            logger.warning("Failed to connect MCP servers (will retry next message): {}", e)
            self._mcp_stacks.clear()
        finally:
            self._mcp_connecting = False

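The connect-once-and-retry pattern in `_connect_mcp` can be sketched in isolation. `LazyConnector` and its `connect` argument are hypothetical stand-ins for the loop's `_mcp_*` attributes and `connect_mcp_servers`; unlike the original, this sketch catches only `Exception`, not `BaseException`.

```python
import asyncio
from typing import Awaitable, Callable


class LazyConnector:
    """Connect once on first use; retry on the next call if nothing connected.

    Hypothetical helper mirroring the _connect_mcp flags. `connect` stands in
    for a function like connect_mcp_servers and returns a dict of handles.
    """

    def __init__(self, connect: Callable[[], Awaitable[dict]]) -> None:
        self._connect = connect
        self.connected = False
        self.connecting = False
        self.handles: dict = {}

    async def ensure_connected(self) -> None:
        # Skip if already connected or if an attempt is already in flight.
        if self.connected or self.connecting:
            return
        self.connecting = True
        try:
            self.handles = await self._connect()
            # Only mark as connected when at least one handle came back,
            # so an empty result is retried on the next call.
            self.connected = bool(self.handles)
        except Exception:
            self.handles = {}
        finally:
            self.connecting = False
```

The `connecting` flag keeps concurrent callers from starting a second attempt, and leaving `connected` false after an empty or failed attempt is what makes the next message retry.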
    def _set_tool_context(
        self, channel: str, chat_id: str,
        message_id: str | None = None, metadata: dict | None = None,
        session_key: str | None = None,
    ) -> None:
        """Update context for all tools that need routing info."""
        from nanobot.agent.tools.context import ContextAware, RequestContext

        if session_key is not None:
            effective_key = session_key
        elif self._unified_session:
            effective_key = UNIFIED_SESSION_KEY
        else:
            effective_key = f"{channel}:{chat_id}"

        request_ctx = RequestContext(
            channel=channel,
            chat_id=chat_id,
            message_id=message_id,
            session_key=effective_key,
            metadata=dict(metadata or {}),
        )

        for name in self.tools.tool_names:
            tool = self.tools.get(name)
            if tool and isinstance(tool, ContextAware):
                tool.set_context(request_ctx)

    @staticmethod
    def _runtime_chat_id(msg: InboundMessage) -> str:
        """Return the chat id shown in runtime metadata for the model."""
        return str(msg.metadata.get("context_chat_id") or msg.chat_id)

    async def _build_bus_progress_callback(
        self, msg: InboundMessage
    ) -> Callable[..., Awaitable[None]]:
        """Build a progress callback that publishes to the message bus."""

        async def _bus_progress(
            content: str,
            *,
            tool_hint: bool = False,
            tool_events: list[dict[str, Any]] | None = None,
            reasoning: bool = False,
            reasoning_end: bool = False,
        ) -> None:
            meta = dict(msg.metadata or {})
            meta["_progress"] = True
            meta["_tool_hint"] = tool_hint
            if reasoning:
                meta["_reasoning_delta"] = True
            if reasoning_end:
                meta["_reasoning_end"] = True
            if tool_events:
                meta["_tool_events"] = tool_events
            await self.bus.publish_outbound(
                OutboundMessage(
                    channel=msg.channel,
                    chat_id=msg.chat_id,
                    content=content,
                    metadata=meta,
                )
            )

        return _bus_progress

    async def _build_retry_wait_callback(
        self, msg: InboundMessage
    ) -> Callable[[str], Awaitable[None]]:
        """Build a retry-wait callback that publishes to the message bus."""

        async def _on_retry_wait(content: str) -> None:
            meta = dict(msg.metadata or {})
            meta["_retry_wait"] = True
            await self.bus.publish_outbound(
                OutboundMessage(
                    channel=msg.channel,
                    chat_id=msg.chat_id,
                    content=content,
                    metadata=meta,
                )
            )

        return _on_retry_wait

    def _persist_user_message_early(
        self,
        msg: InboundMessage,
        session: Session,
        **kwargs: Any,
    ) -> bool:
        """Persist the triggering user message before the turn starts.

        Returns True if the message was persisted.
        """
        media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
        has_text = isinstance(msg.content, str) and msg.content.strip()
        if has_text or media_paths:
            extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
            extra.update(kwargs)
            text = msg.content if isinstance(msg.content, str) else ""
            session.add_message("user", text, **extra)
            self._mark_pending_user_turn(session)
            self.sessions.save(session)
            return True
        return False

    def _build_initial_messages(
        self,
        msg: InboundMessage,
        session: Session,
        history: list[dict[str, Any]],
        pending_summary: str | None,
    ) -> list[dict[str, Any]]:
        """Build the initial message list for the LLM turn."""
        return self.context.build_messages(
            history=history,
            current_message=image_generation_prompt(msg.content, msg.metadata),
            media=msg.media if msg.media else None,
            channel=msg.channel,
            chat_id=self._runtime_chat_id(msg),
            sender_id=msg.sender_id,
            session_summary=pending_summary,
            session_metadata=session.metadata,
        )

    async def _dispatch_command_inline(
        self,
        msg: InboundMessage,
        key: str,
        raw: str,
        dispatch_fn: Callable[[CommandContext], Awaitable[OutboundMessage | None]],
    ) -> None:
        """Dispatch a command directly from the run() loop and publish the result."""
        ctx = CommandContext(msg=msg, session=None, key=key, raw=raw, loop=self)
        result = await dispatch_fn(ctx)
        if result:
            await self.bus.publish_outbound(result)
        else:
            logger.warning("Command '{}' matched but dispatch returned None", raw)

    async def _cancel_active_tasks(self, key: str) -> int:
        """Cancel and await all active tasks and subagents for *key*.

        Returns the total number of cancelled tasks + subagents.
        """
        tasks = self._active_tasks.pop(key, [])
        cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
        for t in tasks:
            with suppress(asyncio.CancelledError, Exception):
                await t
        sub_cancelled = await self.subagents.cancel_by_session(key)
        return cancelled + sub_cancelled

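The cancel-then-await step in `_cancel_active_tasks` can be shown on its own with plain `asyncio` tasks. `cancel_and_await` is a hypothetical helper; it leaves out the subagent half (`cancel_by_session`) of the original.

```python
import asyncio
from contextlib import suppress


async def cancel_and_await(tasks: list[asyncio.Task]) -> int:
    """Cancel still-running tasks, wait for all of them, return how many were cancelled."""
    # Task.cancel() returns True only when the request was accepted,
    # and finished tasks are skipped, so they are not counted.
    cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
    for t in tasks:
        # Awaiting a cancelled task raises CancelledError, and a failed task
        # re-raises its own exception; both are swallowed here.
        with suppress(asyncio.CancelledError, Exception):
            await t
    return cancelled
```

Awaiting every task, not just cancelling it, guarantees that cleanup code in the cancelled tasks has run before the caller continues.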
    def _effective_session_key(self, msg: InboundMessage) -> str:
        """Return the session key used for task routing and mid-turn injections."""
        if self._unified_session and not msg.session_key_override:
            return UNIFIED_SESSION_KEY
        return msg.session_key

    def _replay_token_budget(self) -> int:
        """Derive a token budget for session history replay from the context window."""
        if self.context_window_tokens <= 0:
            return 0
        max_output = getattr(getattr(self.provider, "generation", None), "max_tokens", 4096)
        try:
            reserved_output = int(max_output)
        except (TypeError, ValueError):
            reserved_output = 4096
        budget = self.context_window_tokens - max(1, reserved_output) - 1024
        return budget if budget > 0 else max(128, self.context_window_tokens // 2)

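The arithmetic in `_replay_token_budget` is easiest to check as a pure function. `replay_token_budget` is a hypothetical name; it takes the context window and the provider's `max_tokens` directly instead of reading them from `self`. For example, a 128000-token window with `max_tokens=8192` leaves 128000 - 8192 - 1024 = 118784 tokens for replay. A 4096-token window with the default 4096 reserved output goes negative, so it falls back to 4096 // 2 = 2048.

```python
def replay_token_budget(context_window: int, max_output: object = 4096) -> int:
    """Pure-function sketch of _replay_token_budget."""
    if context_window <= 0:
        return 0
    try:
        reserved_output = int(max_output)
    except (TypeError, ValueError):
        # An unparseable max_tokens falls back to 4096.
        reserved_output = 4096
    # Reserve room for the model's reply plus a fixed 1024-token margin.
    budget = context_window - max(1, reserved_output) - 1024
    # If the window is too small for that, use half the window (at least 128).
    return budget if budget > 0 else max(128, context_window // 2)
```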
    async def _run_agent_loop(
        self,
        initial_messages: list[dict],
        on_progress: Callable[..., Awaitable[None]] | None = None,
        on_stream: Callable[[str], Awaitable[None]] | None = None,
        on_stream_end: Callable[..., Awaitable[None]] | None = None,
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
        *,
        session: Session | None = None,
        channel: str = "cli",
        chat_id: str = "direct",
        message_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        session_key: str | None = None,
        pending_queue: asyncio.Queue | None = None,
    ) -> tuple[str | None, list[str], list[dict], str, bool]:
        """Run the agent iteration loop.

        *on_stream*: called with each content delta during streaming.
        *on_stream_end(resuming)*: called when a streaming session finishes.
        ``resuming=True`` means tool calls follow (spinner should restart);
        ``resuming=False`` means this is the final response.

        Returns (final_content, tools_used, messages, stop_reason, had_injections).
        """
        self._sync_subagent_runtime_limits()

        loop_hook = AgentProgressHook(
            on_progress=on_progress,
            on_stream=on_stream,
            on_stream_end=on_stream_end,
            channel=channel,
            chat_id=chat_id,
            message_id=message_id,
            metadata=metadata,
            session_key=session_key,
            tool_hint_max_length=self.tool_hint_max_length,
            set_tool_context=self._set_tool_context,
            on_iteration=lambda iteration: setattr(self, "_current_iteration", iteration),
        )
        hook: AgentHook = (
            CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
        )

        async def _checkpoint(payload: dict[str, Any]) -> None:
            if session is None:
                return
            self._set_runtime_checkpoint(session, payload)

        async def _drain_pending(*, limit: int = _MAX_INJECTIONS_PER_TURN) -> list[dict[str, Any]]:
            """Drain follow-up messages from the pending queue.

            When no messages are immediately available but sub-agents
            spawned in this dispatch are still running, blocks until at
            least one result arrives (or timeout). This keeps the runner
            loop alive so subsequent sub-agent completions are consumed
            in-order rather than dispatched separately.
            """
            if pending_queue is None:
                return []

            def _to_user_message(pending_msg: InboundMessage) -> dict[str, Any]:
                content = pending_msg.content
                media = pending_msg.media if pending_msg.media else None
                if media:
                    content, media = extract_documents(content, media)
                    media = media or None
                user_content = self.context._build_user_content(content, media)
                extra = goal_state_runtime_lines(session.metadata) if session is not None else []
                runtime_ctx = self.context._build_runtime_context(
                    pending_msg.channel,
                    self._runtime_chat_id(pending_msg),
                    self.context.timezone,
                    sender_id=pending_msg.sender_id,
                    supplemental_lines=extra or None,
                )
                if isinstance(user_content, str):
                    merged: str | list[dict[str, Any]] = f"{user_content}\n\n{runtime_ctx}"
                else:
                    merged = user_content + [{"type": "text", "text": runtime_ctx}]
                return {"role": "user", "content": merged}

            items: list[dict[str, Any]] = []
            while len(items) < limit:
                try:
                    items.append(_to_user_message(pending_queue.get_nowait()))
                except asyncio.QueueEmpty:
                    break

            # Block if nothing drained but sub-agents spawned in this dispatch
            # are still running. Keeps the runner loop alive so subsequent
            # completions are injected in-order rather than dispatched separately.
            if (not items
                    and session is not None
                    and self.subagents.get_running_count_by_session(session.key) > 0):
                try:
                    msg = await asyncio.wait_for(pending_queue.get(), timeout=300)
                except asyncio.TimeoutError:
                    logger.warning(
                        "Timeout waiting for sub-agent completion in session {}",
                        session.key,
                    )
                    return items
                items.append(_to_user_message(msg))
                while len(items) < limit:
                    try:
                        items.append(_to_user_message(pending_queue.get_nowait()))
                    except asyncio.QueueEmpty:
                        break

            return items

        active_session_key = session.key if session else session_key
        file_state_token = bind_file_states(self._file_state_store.for_session(active_session_key))
        try:
            result = await self.runner.run(AgentRunSpec(
                initial_messages=initial_messages,
                tools=self.tools,
                model=self.model,
                max_iterations=self.max_iterations,
                max_tool_result_chars=self.max_tool_result_chars,
                hook=hook,
                error_message="Sorry, I encountered an error calling the AI model.",
                concurrent_tools=True,
                workspace=self.workspace,
                session_key=session.key if session else None,
                context_window_tokens=self.context_window_tokens,
                context_block_limit=self.context_block_limit,
                provider_retry_mode=self.provider_retry_mode,
                progress_callback=on_progress,
                stream_progress_deltas=on_stream is not None,
                retry_wait_callback=on_retry_wait,
                checkpoint_callback=_checkpoint,
                injection_callback=_drain_pending,
            ))
        finally:
            reset_file_states(file_state_token)
        self._last_usage = result.usage
        if result.stop_reason == "max_iterations":
            logger.warning("Max iterations ({}) reached", self.max_iterations)
            # Push final content through stream so streaming channels (e.g. Feishu)
            # update the card instead of leaving it empty.
            if on_stream and on_stream_end:
                await on_stream(result.final_content or "")
                await on_stream_end(resuming=False)
        elif result.stop_reason == "error":
            logger.error("LLM returned error: {}", (result.final_content or "")[:200])
        return result.final_content, result.tools_used, result.messages, result.stop_reason, result.had_injections

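The two-phase drain in `_drain_pending` (take what is already queued, and only block when nothing arrived while work is still outstanding) can be sketched with a bare `asyncio.Queue`. `drain` is a hypothetical helper. Its `wait_if_empty` flag stands in for the "sub-agents still running" check, and it returns raw items instead of converting them with `_to_user_message`.

```python
import asyncio


async def drain(queue: asyncio.Queue, limit: int, wait_if_empty: bool, timeout: float) -> list:
    """Take up to `limit` queued items without blocking.

    If nothing was queued and `wait_if_empty` is true, block for one item
    (up to `timeout` seconds), then grab whatever else is already queued.
    """
    items: list = []

    def take_available() -> None:
        # Non-blocking pass: stop at the limit or when the queue is empty.
        while len(items) < limit:
            try:
                items.append(queue.get_nowait())
            except asyncio.QueueEmpty:
                break

    take_available()
    if not items and wait_if_empty:
        try:
            items.append(await asyncio.wait_for(queue.get(), timeout=timeout))
        except asyncio.TimeoutError:
            return items
        take_available()
    return items
```

Blocking only in the empty-and-still-running case keeps the runner alive for late sub-agent results without stalling turns that have nothing left to wait for.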
    async def run(self) -> None:
        """Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
        self._running = True
        await self._connect_mcp()
        logger.info("Agent loop started")

        while self._running:
            try:
                msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
            except asyncio.TimeoutError:
                self.auto_compact.check_expired(
                    self._schedule_background,
                    active_session_keys=self._pending_queues.keys(),
                )
                continue
            except asyncio.CancelledError:
                # Preserve real task cancellation so shutdown can complete cleanly.
                # Only ignore non-task CancelledError signals that may leak from integrations.
                if not self._running or asyncio.current_task().cancelling():
                    raise
                continue
            except Exception as e:
                logger.warning("Error consuming inbound message: {}, continuing...", e)
                continue

            raw = msg.content.strip()
            if self.commands.is_priority(raw):
                await self._dispatch_command_inline(
                    msg, msg.session_key, raw,
                    self.commands.dispatch_priority,
                )
                continue
            effective_key = self._effective_session_key(msg)
            # If this session already has an active pending queue (i.e. a task
            # is processing this session), route the message there for mid-turn
            # injection instead of creating a competing task.
            if effective_key in self._pending_queues:
                # Non-priority commands must not be queued for injection;
                # dispatch them directly (same pattern as priority commands).
                if self.commands.is_dispatchable_command(raw):
                    await self._dispatch_command_inline(
                        msg, effective_key, raw,
                        self.commands.dispatch,
                    )
                    continue
                pending_msg = msg
                if effective_key != msg.session_key:
                    pending_msg = dataclasses.replace(
                        msg,
                        session_key_override=effective_key,
                    )
                try:
                    self._pending_queues[effective_key].put_nowait(pending_msg)
                except asyncio.QueueFull:
                    logger.warning(
                        "Pending queue full for session {}, falling back to queued task",
                        effective_key,
                    )
                else:
                    logger.info(
                        "Routed follow-up message to pending queue for session {}",
                        effective_key,
                    )
                    continue
            # Track the task under the effective session key computed above so
            # /stop can find it when unified sessions are enabled.
            task = asyncio.create_task(self._dispatch(msg))
            self._active_tasks.setdefault(effective_key, []).append(task)
            task.add_done_callback(
                lambda t, k=effective_key: self._active_tasks.get(k, [])
                and self._active_tasks[k].remove(t)
                if t in self._active_tasks.get(k, [])
                else None
            )

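The routing decision in `run()` (inject into an in-flight turn when the session has a pending queue, otherwise start a new task, and fall back to a new task when the queue is full) can be reduced to a small function. `route` and its return labels are hypothetical, and the command-dispatch branch is left out.

```python
import asyncio


def route(pending_queues: dict[str, asyncio.Queue], key: str, msg: object) -> str:
    """Return "injected" if msg joined an in-flight turn, else "new_task"."""
    queue = pending_queues.get(key)
    if queue is None:
        # No turn is running for this session: the caller starts a task.
        return "new_task"
    try:
        queue.put_nowait(msg)
    except asyncio.QueueFull:
        # Queue saturated: fall back to a separate task. The session lock
        # in _dispatch still serializes it behind the current turn.
        return "new_task"
    return "injected"
```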
    async def _dispatch(self, msg: InboundMessage) -> None:
        """Process a message: per-session serial, cross-session concurrent."""
        session_key = self._effective_session_key(msg)
        if session_key != msg.session_key:
            msg = dataclasses.replace(msg, session_key_override=session_key)
        lock = self._session_locks.setdefault(session_key, asyncio.Lock())
        gate = self._concurrency_gate or nullcontext()

        # Register a pending queue so follow-up messages for this session are
        # routed here (mid-turn injection) instead of spawning a new task.
        pending = asyncio.Queue(maxsize=20)
        self._pending_queues[session_key] = pending

        try:
            async with lock, gate:
                try:
                    on_stream = on_stream_end = None
                    if msg.metadata.get("_wants_stream"):
                        # Split one answer into distinct stream segments.
                        stream_base_id = f"{msg.session_key}:{time.time_ns()}"
                        stream_segment = 0

                        def _current_stream_id() -> str:
                            return f"{stream_base_id}:{stream_segment}"

                        async def on_stream(delta: str) -> None:
                            meta = dict(msg.metadata or {})
                            meta["_stream_delta"] = True
                            meta["_stream_id"] = _current_stream_id()
                            await self.bus.publish_outbound(OutboundMessage(
                                channel=msg.channel, chat_id=msg.chat_id,
                                content=delta,
                                metadata=meta,
                            ))

                        async def on_stream_end(*, resuming: bool = False) -> None:
                            nonlocal stream_segment
                            meta = dict(msg.metadata or {})
                            meta["_stream_end"] = True
                            meta["_resuming"] = resuming
                            meta["_stream_id"] = _current_stream_id()
                            await self.bus.publish_outbound(OutboundMessage(
                                channel=msg.channel, chat_id=msg.chat_id,
                                content="",
                                metadata=meta,
                            ))
                            stream_segment += 1

                    response = await self._process_message(
                        msg, on_stream=on_stream, on_stream_end=on_stream_end,
                        pending_queue=pending,
                    )
                    if response is not None:
                        await self.bus.publish_outbound(response)
                    elif msg.channel == "cli":
                        await self.bus.publish_outbound(OutboundMessage(
                            channel=msg.channel, chat_id=msg.chat_id,
                            content="", metadata=msg.metadata or {},
                        ))
                    if msg.channel == "websocket":
                        # Signal that the turn is fully complete (all tools executed,
                        # final text streamed). This lets WS clients know when to
                        # definitively stop the loading indicator.
                        turn_lat = self._pending_turn_latency_ms.pop(session_key, None)
                        turn_metadata: dict[str, Any] = {**msg.metadata, "_turn_end": True}
                        if turn_lat is not None:
                            turn_metadata["latency_ms"] = int(turn_lat)
                        sess_turn = self.sessions.get_or_create(session_key)
                        turn_metadata["goal_state"] = goal_state_ws_blob(sess_turn.metadata)
                        await self.bus.publish_outbound(OutboundMessage(
                            channel=msg.channel, chat_id=msg.chat_id,
                            content="", metadata=turn_metadata,
                        ))
                    if msg.metadata.get("webui") is True:
                        async def _generate_title_and_notify() -> None:
                            generated = await maybe_generate_webui_title_after_turn(
                                channel=msg.channel,
                                metadata=msg.metadata,
                                sessions=self.sessions,
                                session_key=session_key,
                                provider=self.provider,
                                model=self.model,
                            )
                            if generated:
                                await self.bus.publish_outbound(OutboundMessage(
                                    channel=msg.channel,
                                    chat_id=msg.chat_id,
                                    content="",
                                    metadata={**msg.metadata, "_session_updated": True},
                                ))

                        self._schedule_background(_generate_title_and_notify())
                except asyncio.CancelledError:
                    logger.info("Task cancelled for session {}", session_key)
                    # Preserve partial context from the interrupted turn so
                    # the user does not lose tool results and assistant
                    # messages accumulated before /stop. The checkpoint was
                    # already persisted to session metadata by
                    # _emit_checkpoint during tool execution; materializing
                    # it into session history now makes it visible in the
                    # next conversation turn.
                    try:
                        key = self._effective_session_key(msg)
                        session = self.sessions.get_or_create(key)
                        if self._restore_runtime_checkpoint(session):
                            self._clear_pending_user_turn(session)
                            self.sessions.save(session)
                            logger.info(
                                "Restored partial context for cancelled session {}",
                                key,
                            )
                    except Exception:
                        logger.debug(
                            "Could not restore checkpoint for cancelled session {}",
                            session_key,
                            exc_info=True,
                        )
                    raise
                except Exception:
                    logger.exception("Error processing message for session {}", session_key)
                    await self.bus.publish_outbound(OutboundMessage(
                        channel=msg.channel, chat_id=msg.chat_id,
                        content="Sorry, I encountered an error.",
                    ))
        finally:
            # Drain any messages still in the pending queue and re-publish
            # them to the bus so they are processed as fresh inbound messages
            # rather than silently lost.
            queue = self._pending_queues.pop(session_key, None)
            if queue is not None:
                leftover = 0
                while True:
                    try:
                        item = queue.get_nowait()
                    except asyncio.QueueEmpty:
                        break
                    await self.bus.publish_inbound(item)
                    leftover += 1
                if leftover:
                    logger.info(
                        "Re-published {} leftover message(s) to bus for session {}",
                        leftover, session_key,
                    )
            await publish_turn_run_status(self.bus, msg, "idle")
            self._pending_turn_latency_ms.pop(session_key, None)

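The stream-segment ids built in `_dispatch` follow a simple rule. All deltas in a segment share `"<base>:<segment>"`, and `on_stream_end` carries the id of the segment it closes before bumping the counter. `StreamIds` is a hypothetical class that replaces the closure and `nonlocal` counter. The base string in the usage below is made up; the real base is `f"{msg.session_key}:{time.time_ns()}"`.

```python
class StreamIds:
    """Hand out per-segment stream ids of the form "<base>:<segment>"."""

    def __init__(self, base: str) -> None:
        self.base = base
        self.segment = 0

    def current(self) -> str:
        # Id attached to every delta of the segment in progress.
        return f"{self.base}:{self.segment}"

    def end_segment(self) -> str:
        # The end marker carries the id of the segment it closes; the next
        # segment (e.g. text after tool calls resume) gets a fresh id.
        closed = self.current()
        self.segment += 1
        return closed
```

Giving each segment its own id lets a client render text before and after a tool call as separate bubbles instead of appending to one.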
    async def close_mcp(self) -> None:
        """Drain pending background archives, then close MCP connections."""
        if self._background_tasks:
            await asyncio.gather(*self._background_tasks, return_exceptions=True)
            self._background_tasks.clear()
        for name, stack in self._mcp_stacks.items():
            try:
                await stack.aclose()
            except (RuntimeError, BaseExceptionGroup):
                logger.debug("MCP server '{}' cleanup error (can be ignored)", name)
        self._mcp_stacks.clear()

    def _schedule_background(self, coro) -> None:
        """Schedule a coroutine as a tracked background task (drained on shutdown)."""
        task = asyncio.create_task(coro)
        self._background_tasks.append(task)
        task.add_done_callback(self._background_tasks.remove)

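`_schedule_background` and the first half of `close_mcp` together form a "tracked fire-and-forget" pattern. Each task removes itself from the list when it finishes, and shutdown gathers whatever is still pending. `BackgroundTasks` is a hypothetical wrapper around that pattern; it does not include the MCP stack cleanup.

```python
import asyncio
from collections.abc import Coroutine
from typing import Any


class BackgroundTasks:
    """Track fire-and-forget tasks so shutdown can wait for them."""

    def __init__(self) -> None:
        self.tasks: list[asyncio.Task] = []

    def schedule(self, coro: Coroutine[Any, Any, Any]) -> None:
        task = asyncio.create_task(coro)
        self.tasks.append(task)
        # Drop the reference once the task finishes so the list holds only live tasks.
        task.add_done_callback(self.tasks.remove)

    async def drain(self) -> list:
        # return_exceptions=True keeps one failing task from hiding the others' results.
        return list(await asyncio.gather(*self.tasks, return_exceptions=True))
```

Holding a strong reference in the list also keeps the event loop from garbage-collecting a task that nothing else references.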
    def stop(self) -> None:
        """Stop the agent loop."""
        self._running = False
        logger.info("Agent loop stopping")

    async def _process_system_message(
        self,
        msg: InboundMessage,
        session_key: str | None = None,
        on_progress: Callable[..., Awaitable[None]] | None = None,
        on_stream: Callable[[str], Awaitable[None]] | None = None,
        on_stream_end: Callable[..., Awaitable[None]] | None = None,
        pending_queue: asyncio.Queue | None = None,
    ) -> OutboundMessage | None:
        """Process a system inbound message (e.g. subagent announce)."""
        channel, chat_id = (
            msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
        )
        logger.info("Processing system message from {}", msg.sender_id)
        key = msg.session_key_override or f"{channel}:{chat_id}"
        session = self.sessions.get_or_create(key)
        if self._restore_runtime_checkpoint(session):
            self.sessions.save(session)
        if self._restore_pending_user_turn(session):
            self.sessions.save(session)

        session, pending = self.auto_compact.prepare_session(session, key)
        if pending:
            logger.info("Memory compact triggered for session {}", key)

        await self.consolidator.maybe_consolidate_by_tokens(
            session,
            replay_max_messages=self._max_messages,
        )
        is_subagent = msg.sender_id == "subagent"
        if is_subagent and self._persist_subagent_followup(session, msg):
            logger.debug("Subagent result persisted for session {}", key)
            self.sessions.save(session)
        self._set_tool_context(
            channel, chat_id, msg.metadata.get("message_id"),
            msg.metadata, session_key=key,
        )
        _hist_kwargs: dict[str, Any] = {
            "max_messages": self._max_messages,
            "max_tokens": self._replay_token_budget(),
            "include_timestamps": True,
        }
        history = session.get_history(**_hist_kwargs)
        current_role = "assistant" if is_subagent else "user"

        messages = self.context.build_messages(
            history=history,
            current_message="" if is_subagent else msg.content,
            channel=channel,
            chat_id=chat_id,
            current_role=current_role,
            sender_id=msg.sender_id,
            session_summary=pending,
            session_metadata=session.metadata,
        )
        t_wall = time.time()
        final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop(
            messages, session=session, channel=channel, chat_id=chat_id,
            message_id=msg.metadata.get("message_id"),
            metadata=msg.metadata,
            session_key=key,
            pending_queue=pending_queue,
        )
        wall_done = time.time()
        latency_ms = max(0, int((wall_done - t_wall) * 1000))
        self._save_turn(session, all_msgs, 1 + len(history), turn_latency_ms=latency_ms)
        if channel == "websocket":
            self._pending_turn_latency_ms[key] = latency_ms
        session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
        self._clear_runtime_checkpoint(session)
        self.sessions.save(session)
        self._schedule_background(
            self.consolidator.maybe_consolidate_by_tokens(
                session,
                replay_max_messages=self._max_messages,
            )
        )
        content = final_content or "Background task completed."
        outbound_metadata: dict[str, Any] = {}
        if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2:
            outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]}
        if origin_message_id := msg.metadata.get("origin_message_id"):
            outbound_metadata["origin_message_id"] = origin_message_id
        return OutboundMessage(
            channel=channel,
            chat_id=chat_id,
            content=content,
            metadata=outbound_metadata,
        )

    async def _process_message(
        self,
        msg: InboundMessage,
        session_key: str | None = None,
        on_progress: Callable[..., Awaitable[None]] | None = None,
        on_stream: Callable[[str], Awaitable[None]] | None = None,
        on_stream_end: Callable[..., Awaitable[None]] | None = None,
        pending_queue: asyncio.Queue | None = None,
    ) -> OutboundMessage | None:
        """Process a single inbound message and return the response."""
        self._refresh_provider_snapshot()

        if msg.channel == "system":
            return await self._process_system_message(
                msg,
                session_key=session_key,
                on_progress=on_progress,
                on_stream=on_stream,
                on_stream_end=on_stream_end,
                pending_queue=pending_queue,
            )

        key = session_key or msg.session_key
        ctx = TurnContext(
            msg=msg,
            session=None,
            session_key=key,
            state=TurnState.RESTORE,
            turn_id=f"{key}:{time.time_ns()}",
            on_progress=on_progress,
            on_stream=on_stream,
            on_stream_end=on_stream_end,
            pending_queue=pending_queue,
        )

        while ctx.state is not TurnState.DONE:
            handler_name = f"_state_{ctx.state.name.lower()}"
            handler = getattr(self, handler_name, None)
            if handler is None:
                raise RuntimeError(f"Missing state handler for {ctx.state}")

            t0 = time.perf_counter()
            try:
                event = await handler(ctx)
            except Exception:
                duration = (time.perf_counter() - t0) * 1000
                ctx.trace.append(
                    StateTraceEntry(
                        state=ctx.state,
                        started_at=t0,
                        duration_ms=duration,
                        event="",
                        error="exception",
                    )
                )
                raise

            duration = (time.perf_counter() - t0) * 1000
            ctx.trace.append(
                StateTraceEntry(
                    state=ctx.state,
                    started_at=t0,
                    duration_ms=duration,
                    event=event,
                )
            )
            logger.debug(
                "[turn {}] State {} took {:.1f}ms -> event {}",
                ctx.turn_id,
                ctx.state.name,
                duration,
                event,
            )

            next_state = self._TRANSITIONS.get((ctx.state, event))
            if next_state is None:
                raise RuntimeError(
                    f"[turn {ctx.turn_id}] No transition from {ctx.state} "
                    f"on event {event!r}"
                )
            ctx.state = next_state

        logger.debug(
            "[turn {}] Turn completed after {} states",
            ctx.turn_id,
            len(ctx.trace),
        )
        return ctx.outbound

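The driver loop in `_process_message` is a table-driven state machine. Each state maps by name to an `async` handler (`RESTORE` to `_state_restore`), the handler returns an event string, and the `(state, event)` pair is looked up in a transition table. The sketch below uses a hypothetical three-state subset; the real `TurnState` values and `_TRANSITIONS` table are defined outside this excerpt.

```python
import asyncio
from enum import Enum, auto


class State(Enum):
    # Hypothetical subset of TurnState, for illustration only.
    RESTORE = auto()
    RUN = auto()
    DONE = auto()


TRANSITIONS = {
    (State.RESTORE, "ok"): State.RUN,
    (State.RUN, "ok"): State.DONE,
}


class Turn:
    def __init__(self) -> None:
        self.state = State.RESTORE
        self.trace: list[tuple[str, str]] = []

    async def _state_restore(self) -> str:
        return "ok"

    async def _state_run(self) -> str:
        return "ok"

    async def drive(self) -> list[tuple[str, str]]:
        while self.state is not State.DONE:
            # Handlers are looked up by name, e.g. RESTORE -> _state_restore.
            handler = getattr(self, f"_state_{self.state.name.lower()}", None)
            if handler is None:
                raise RuntimeError(f"Missing state handler for {self.state}")
            event = await handler()
            self.trace.append((self.state.name, event))
            # An unknown (state, event) pair is a programming error, not a silent stop.
            next_state = TRANSITIONS.get((self.state, event))
            if next_state is None:
                raise RuntimeError(f"No transition from {self.state} on {event!r}")
            self.state = next_state
        return self.trace
```

Keeping the transitions in one table means a new step (for example a shortcut path that skips BUILD and SAVE) only needs a handler and new table rows, not changes to the loop.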
    def _assemble_outbound(
        self,
        msg: InboundMessage,
        final_content: str,
        all_msgs: list[dict[str, Any]],
        stop_reason: str,
        had_injections: bool,
        generated_media: list[str],
        on_stream: Callable[[str], Awaitable[None]] | None,
        *,
        turn_latency_ms: int | None = None,
    ) -> OutboundMessage | None:
        """Assemble the final outbound message from turn results."""
        # If the message tool already sent a reply this turn, suppress the final
        # response unless follow-ups were injected mid-turn and it is non-empty.
        if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
            if not had_injections or stop_reason == "empty_final_response":
                return None

        preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
        logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)

        meta = dict(msg.metadata or {})
        if on_stream is not None and stop_reason not in {"error", "tool_error"}:
            meta["_streamed"] = True
        if turn_latency_ms is not None:
            meta["latency_ms"] = int(turn_latency_ms)

        return OutboundMessage(
            channel=msg.channel,
            chat_id=msg.chat_id,
            content=final_content,
            media=generated_media,
            metadata=meta,
        )

    async def _state_restore(self, ctx: TurnContext) -> str:
        """Restore checkpoint / pending user turn; extract documents."""
        msg = ctx.msg

        if msg.media:
            new_content, image_only = extract_documents(msg.content, msg.media)
            ctx.msg = dataclasses.replace(msg, content=new_content, media=image_only)
            msg = ctx.msg

        preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
        logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)

        # _process_message creates the TurnContext with session=None, so fetch
        # (or create) the session here before restoring any saved state.
        if ctx.session is None:
            ctx.session = self.sessions.get_or_create(ctx.session_key)
        mark_webui_session(ctx.session, msg.metadata)

        if self._restore_runtime_checkpoint(ctx.session):
            self.sessions.save(ctx.session)
        if self._restore_pending_user_turn(ctx.session):
            self.sessions.save(ctx.session)

        return "ok"

async def _state_compact(self, ctx: TurnContext) -> str:
|
|
ctx.session, pending = self.auto_compact.prepare_session(ctx.session, ctx.session_key)
|
|
ctx.pending_summary = pending
|
|
return "ok"
|
|
|
|
async def _state_command(self, ctx: TurnContext) -> str:
|
|
raw = ctx.msg.content.strip()
|
|
cmd_ctx = CommandContext(
|
|
msg=ctx.msg, session=ctx.session, key=ctx.session_key, raw=raw, loop=self
|
|
)
|
|
result = await self.commands.dispatch(cmd_ctx)
|
|
if result is not None:
|
|
ctx.outbound = result
|
|
# Shortcut commands skip BUILD and SAVE, so we must persist the
|
|
# turn here so WebUI history hydration after _turn_end sees the
|
|
# message. Mark messages with _command so get_history can filter
|
|
# them out of LLM context. /new is excluded because it
|
|
# intentionally clears the session.
|
|
if raw.lower() != "/new":
|
|
ctx.user_persisted_early = self._persist_user_message_early(
|
|
ctx.msg, ctx.session, _command=True
|
|
)
|
|
ctx.session.add_message(
|
|
"assistant", result.content, _command=True
|
|
)
|
|
self.sessions.save(ctx.session)
|
|
self._clear_pending_user_turn(ctx.session)
|
|
return "shortcut"
|
|
return "dispatch"
|
|
|
|
    async def _state_build(self, ctx: TurnContext) -> str:
        await self.consolidator.maybe_consolidate_by_tokens(
            ctx.session,
            replay_max_messages=self._max_messages,
        )
        self._set_tool_context(
            ctx.msg.channel,
            ctx.msg.chat_id,
            ctx.msg.metadata.get("message_id"),
            ctx.msg.metadata,
            session_key=ctx.session_key,
        )
        if message_tool := self.tools.get("message"):
            if isinstance(message_tool, MessageTool):
                message_tool.start_turn()

        _hist_kwargs: dict[str, Any] = {
            "max_messages": self._max_messages,
            "max_tokens": self._replay_token_budget(),
            "include_timestamps": True,
        }
        ctx.history = ctx.session.get_history(**_hist_kwargs)

        ctx.initial_messages = self._build_initial_messages(
            ctx.msg, ctx.session, ctx.history, ctx.pending_summary
        )
        ctx.user_persisted_early = self._persist_user_message_early(
            ctx.msg, ctx.session
        )

        if ctx.on_progress is None:
            ctx.on_progress = await self._build_bus_progress_callback(ctx.msg)
        if ctx.on_retry_wait is None:
            ctx.on_retry_wait = await self._build_retry_wait_callback(ctx.msg)

        return "ok"

    async def _state_run(self, ctx: TurnContext) -> str:
        await publish_turn_run_status(self.bus, ctx.msg, "running")
        result = await self._run_agent_loop(
            ctx.initial_messages,
            on_progress=ctx.on_progress,
            on_stream=ctx.on_stream,
            on_stream_end=ctx.on_stream_end,
            on_retry_wait=ctx.on_retry_wait,
            session=ctx.session,
            channel=ctx.msg.channel,
            chat_id=ctx.msg.chat_id,
            message_id=ctx.msg.metadata.get("message_id"),
            metadata=ctx.msg.metadata,
            session_key=ctx.session_key,
            pending_queue=ctx.pending_queue,
        )
        final_content, tools_used, all_msgs, stop_reason, had_injections = result
        ctx.final_content = final_content
        ctx.tools_used = tools_used
        ctx.all_messages = all_msgs
        ctx.stop_reason = stop_reason
        ctx.had_injections = had_injections
        return "ok"

    async def _state_save(self, ctx: TurnContext) -> str:
        if ctx.final_content is None or not ctx.final_content.strip():
            ctx.final_content = EMPTY_FINAL_RESPONSE_MESSAGE

        ctx.save_skip = 1 + len(ctx.history) + (1 if ctx.user_persisted_early else 0)
        skip_msgs = ctx.all_messages[ctx.save_skip:]
        ctx.generated_media = generated_image_paths_from_messages(skip_msgs)
        mt = self.tools.get("message")
        extra = getattr(mt, "turn_delivered_media_paths", lambda: [])() if mt else []
        merge_turn_media_into_last_assistant(ctx.all_messages, ctx.generated_media, extra)

        ctx.turn_latency_ms = max(0, int((time.time() - ctx.turn_wall_started_at) * 1000))
        self._save_turn(
            ctx.session, ctx.all_messages, ctx.save_skip,
            turn_latency_ms=ctx.turn_latency_ms,
        )
        if ctx.msg.channel == "websocket":
            self._pending_turn_latency_ms[ctx.session_key] = ctx.turn_latency_ms
        ctx.session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
        self._clear_pending_user_turn(ctx.session)
        self._clear_runtime_checkpoint(ctx.session)
        self.sessions.save(ctx.session)
        self._schedule_background(
            self.consolidator.maybe_consolidate_by_tokens(
                ctx.session,
                replay_max_messages=self._max_messages,
            )
        )
        return "ok"

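The `save_skip` formula in `_state_save` only works if `ctx.all_messages` starts with exactly what `_build_initial_messages` produced. The sketch below assumes that layout is one leading message (presumably the system prompt), then the replayed `ctx.history`, then the current user message; `new_turn_messages` is a hypothetical helper, not part of the codebase, that shows which slice reaches `_save_turn`.

```python
from typing import Any


def new_turn_messages(
    all_messages: list[dict[str, Any]],
    history_len: int,
    user_persisted_early: bool,
) -> list[dict[str, Any]]:
    # Hypothetical helper mirroring the save_skip formula in _state_save:
    # skip the leading message (assumed to be the system prompt), the replayed
    # history, and the current user message if it was already persisted.
    save_skip = 1 + history_len + (1 if user_persisted_early else 0)
    return all_messages[save_skip:]


messages = [
    {"role": "system", "content": "You are a helpful agent."},
    {"role": "user", "content": "earlier question"},     # replayed history
    {"role": "assistant", "content": "earlier answer"},  # replayed history
    {"role": "user", "content": "current question"},
    {"role": "assistant", "content": "current answer"},
]

print([m["content"] for m in new_turn_messages(messages, 2, True)])
# ['current answer']
print([m["content"] for m in new_turn_messages(messages, 2, False)])
# ['current question', 'current answer']
```

When the user message was already persisted early, it is excluded from the slice so that `_save_turn` does not write it a second time.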
    async def _state_respond(self, ctx: TurnContext) -> str:
        ctx.outbound = self._assemble_outbound(
            ctx.msg,
            ctx.final_content,
            ctx.all_messages,
            ctx.stop_reason,
            ctx.had_injections,
            ctx.generated_media,
            ctx.on_stream,
            turn_latency_ms=ctx.turn_latency_ms,
        )
        return "ok"

    def _sanitize_persisted_blocks(
        self,
        content: list[dict[str, Any]],
        *,
        should_truncate_text: bool = False,
        drop_runtime: bool = False,
    ) -> list[dict[str, Any]]:
        """Strip volatile payloads (inline images, optionally runtime context) and
        optionally truncate text blocks before writing session history."""
        filtered: list[dict[str, Any]] = []
        for block in content:
            if not isinstance(block, dict):
                filtered.append(block)
                continue

            if (
                drop_runtime
                and block.get("type") == "text"
                and isinstance(block.get("text"), str)
                and block["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG)
            ):
                continue

            if block.get("type") == "image_url" and block.get("image_url", {}).get(
                "url", ""
            ).startswith("data:image/"):
                path = (block.get("_meta") or {}).get("path", "")
                filtered.append({"type": "text", "text": image_placeholder_text(path)})
                continue

            if block.get("type") == "text" and isinstance(block.get("text"), str):
                text = block["text"]
                if should_truncate_text and len(text) > self.max_tool_result_chars:
                    text = truncate_text_fn(text, self.max_tool_result_chars)
                filtered.append({**block, "text": text})
                continue

            filtered.append(block)

        return filtered

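A standalone sketch of the same filtering rules, runnable without the agent loop. The runtime tag value, the placeholder wording, and the truncation marker are assumptions standing in for `ContextBuilder._RUNTIME_CONTEXT_TAG`, `image_placeholder_text`, and `truncate_text_fn`, whose exact output this excerpt does not show.

```python
from typing import Any

# Hypothetical stand-ins: the real tag is ContextBuilder._RUNTIME_CONTEXT_TAG, and
# image_placeholder_text / truncate_text_fn may format their output differently.
RUNTIME_TAG = "[Runtime Context]"
MAX_CHARS = 20


def placeholder(path: str) -> str:
    return f"[image: {path}]" if path else "[image]"


def truncate(text: str, limit: int) -> str:
    return text[:limit] + "...(truncated)"


def sanitize_blocks(
    content: list[Any], *, truncate_text: bool = False, drop_runtime: bool = False
) -> list[Any]:
    out: list[Any] = []
    for block in content:
        if not isinstance(block, dict):
            out.append(block)  # non-dict blocks pass through untouched
            continue
        kind, text = block.get("type"), block.get("text")
        if (
            drop_runtime
            and kind == "text"
            and isinstance(text, str)
            and text.startswith(RUNTIME_TAG)
        ):
            continue  # per-request runtime context is not persisted
        if kind == "image_url" and block.get("image_url", {}).get("url", "").startswith(
            "data:image/"
        ):
            # Inline base64 images are replaced by a short text placeholder.
            path = (block.get("_meta") or {}).get("path", "")
            out.append({"type": "text", "text": placeholder(path)})
            continue
        if kind == "text" and isinstance(text, str):
            if truncate_text and len(text) > MAX_CHARS:
                text = truncate(text, MAX_CHARS)
            out.append({**block, "text": text})
            continue
        out.append(block)
    return out


blocks = [
    {"type": "text", "text": RUNTIME_TAG + " time=12:00"},
    {
        "type": "image_url",
        "image_url": {"url": "data:image/png;base64,AAAA"},
        "_meta": {"path": "shot.png"},
    },
    {"type": "text", "text": "x" * 50},
]
print(sanitize_blocks(blocks, truncate_text=True, drop_runtime=True))
```

Replacing base64 images with a path placeholder keeps the session file small while still telling the model, on replay, that an image was present.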
    def _save_turn(
        self,
        session: Session,
        messages: list[dict],
        skip: int,
        *,
        turn_latency_ms: int | None = None,
    ) -> None:
        """Append this turn's new messages to the session, truncating large tool
        results and stripping runtime context from user messages."""
        from datetime import datetime

        last_assistant_idx: int | None = None
        for m in messages[skip:]:
            entry = dict(m)
            role, content = entry.get("role"), entry.get("content")
            if role == "assistant" and not content and not entry.get("tool_calls"):
                continue  # skip empty assistant messages; they poison session context
            if role == "tool":
                if isinstance(content, str) and len(content) > self.max_tool_result_chars:
                    entry["content"] = truncate_text_fn(content, self.max_tool_result_chars)
                elif isinstance(content, list):
                    filtered = self._sanitize_persisted_blocks(content, should_truncate_text=True)
                    if not filtered:
                        continue
                    entry["content"] = filtered
            elif role == "user":
                if isinstance(content, str) and ContextBuilder._RUNTIME_CONTEXT_TAG in content:
                    # Strip the runtime-context block appended at the end.
                    tag_pos = content.find(ContextBuilder._RUNTIME_CONTEXT_TAG)
                    before = content[:tag_pos].rstrip("\n ")
                    if before:
                        entry["content"] = before
                    else:
                        continue
                if isinstance(content, list):
                    filtered = self._sanitize_persisted_blocks(content, drop_runtime=True)
                    if not filtered:
                        continue
                    entry["content"] = filtered
            entry.setdefault("timestamp", datetime.now().isoformat())
            session.messages.append(entry)
            if role == "assistant":
                last_assistant_idx = len(session.messages) - 1
        if turn_latency_ms is not None and last_assistant_idx is not None:
            session.messages[last_assistant_idx]["latency_ms"] = int(turn_latency_ms)
        session.updated_at = datetime.now()

    def _persist_subagent_followup(self, session: Session, msg: InboundMessage) -> bool:
        """Persist subagent follow-ups before prompt assembly so history stays durable.

        Returns True if a new entry was appended; False if the follow-up was
        deduped (same ``subagent_task_id`` already in session) or carries no
        content worth persisting.
        """
        if not msg.content:
            return False
        task_id = msg.metadata.get("subagent_task_id") if isinstance(msg.metadata, dict) else None
        if task_id and any(
            m.get("injected_event") == "subagent_result" and m.get("subagent_task_id") == task_id
            for m in session.messages
        ):
            return False
        session.add_message(
            "assistant",
            msg.content,
            sender_id=msg.sender_id,
            injected_event="subagent_result",
            subagent_task_id=task_id,
        )
        return True

    def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
        """Persist the latest in-flight turn state into session metadata."""
        session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload
        self.sessions.save(session)

    def _mark_pending_user_turn(self, session: Session) -> None:
        session.metadata[self._PENDING_USER_TURN_KEY] = True

    def _clear_pending_user_turn(self, session: Session) -> None:
        session.metadata.pop(self._PENDING_USER_TURN_KEY, None)

    def _clear_runtime_checkpoint(self, session: Session) -> None:
        if self._RUNTIME_CHECKPOINT_KEY in session.metadata:
            session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)

    @staticmethod
    def _checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]:
        return (
            message.get("role"),
            message.get("content"),
            message.get("tool_call_id"),
            message.get("name"),
            message.get("tool_calls"),
            message.get("reasoning_content"),
            message.get("thinking_blocks"),
        )

    def _restore_runtime_checkpoint(self, session: Session) -> bool:
        """Materialize an unfinished turn into session history before a new request."""
        from datetime import datetime

        checkpoint = session.metadata.get(self._RUNTIME_CHECKPOINT_KEY)
        if not isinstance(checkpoint, dict):
            return False

        assistant_message = checkpoint.get("assistant_message")
        completed_tool_results = checkpoint.get("completed_tool_results") or []
        pending_tool_calls = checkpoint.get("pending_tool_calls") or []

        restored_messages: list[dict[str, Any]] = []
        if isinstance(assistant_message, dict):
            restored = dict(assistant_message)
            restored.setdefault("timestamp", datetime.now().isoformat())
            restored_messages.append(restored)
        for message in completed_tool_results:
            if isinstance(message, dict):
                restored = dict(message)
                restored.setdefault("timestamp", datetime.now().isoformat())
                restored_messages.append(restored)
        for tool_call in pending_tool_calls:
            if not isinstance(tool_call, dict):
                continue
            tool_id = tool_call.get("id")
            name = ((tool_call.get("function") or {}).get("name")) or "tool"
            restored_messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_id,
                    "name": name,
                    "content": "Error: Task interrupted before this tool finished.",
                    "timestamp": datetime.now().isoformat(),
                }
            )

        overlap = 0
        max_overlap = min(len(session.messages), len(restored_messages))
        for size in range(max_overlap, 0, -1):
            existing = session.messages[-size:]
            restored = restored_messages[:size]
            if all(
                self._checkpoint_message_key(left) == self._checkpoint_message_key(right)
                for left, right in zip(existing, restored)
            ):
                overlap = size
                break
        session.messages.extend(restored_messages[overlap:])

        self._clear_pending_user_turn(session)
        self._clear_runtime_checkpoint(session)
        return True

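The overlap check in `_restore_runtime_checkpoint` avoids replaying messages that already reached history before the crash: it finds the longest suffix of the existing history that equals a prefix of the restored messages, then appends only the remainder. Below is a minimal sketch of that merge with plain lists and hypothetical helper names (`message_key`, `merge_checkpoint`, `interrupted_results`); unlike the method, it returns a new list instead of extending `session.messages` in place, and it omits timestamps.

```python
from typing import Any

KEY_FIELDS = (
    "role", "content", "tool_call_id", "name",
    "tool_calls", "reasoning_content", "thinking_blocks",
)


def message_key(m: dict[str, Any]) -> tuple[Any, ...]:
    # Same fields _checkpoint_message_key compares; timestamps are ignored.
    return tuple(m.get(k) for k in KEY_FIELDS)


def merge_checkpoint(
    history: list[dict[str, Any]], restored: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    # Largest k such that the last k history entries equal the first k restored ones.
    overlap = 0
    for size in range(min(len(history), len(restored)), 0, -1):
        if all(
            message_key(a) == message_key(b)
            for a, b in zip(history[-size:], restored[:size])
        ):
            overlap = size
            break
    return history + restored[overlap:]


def interrupted_results(pending_tool_calls: list[Any]) -> list[dict[str, Any]]:
    # Every tool call still in flight gets a synthetic error result, so the
    # assistant message's tool_calls stay paired with tool messages.
    return [
        {
            "role": "tool",
            "tool_call_id": call.get("id"),
            "name": (call.get("function") or {}).get("name") or "tool",
            "content": "Error: Task interrupted before this tool finished.",
        }
        for call in pending_tool_calls
        if isinstance(call, dict)
    ]


call = {"id": "c1", "function": {"name": "read_file"}}
assistant = {"role": "assistant", "content": None, "tool_calls": [call]}
history = [{"role": "user", "content": "open it"}, assistant]
restored = [assistant, *interrupted_results([call])]
merged = merge_checkpoint(history, restored)
print([m["role"] for m in merged])  # ['user', 'assistant', 'tool']
```

Because `message_key` ignores timestamps, an assistant message that was already saved before the crash still matches its checkpointed copy and is not duplicated.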
    def _restore_pending_user_turn(self, session: Session) -> bool:
        """Close a turn that only persisted the user message before crashing."""
        from datetime import datetime

        if not session.metadata.get(self._PENDING_USER_TURN_KEY):
            return False

        if session.messages and session.messages[-1].get("role") == "user":
            session.messages.append(
                {
                    "role": "assistant",
                    "content": "Error: Task interrupted before a response was generated.",
                    "timestamp": datetime.now().isoformat(),
                }
            )
            session.updated_at = datetime.now()

        self._clear_pending_user_turn(session)
        return True

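A sketch of the crash-recovery path in `_restore_pending_user_turn`, using plain dicts for `session.messages` and `session.metadata`. `PENDING_KEY` is a hypothetical value for `_PENDING_USER_TURN_KEY`, and the flag is assumed to be set (via `_mark_pending_user_turn`) when the user message is persisted before the model runs.

```python
from datetime import datetime
from typing import Any

PENDING_KEY = "pending_user_turn"  # hypothetical; the real key is _PENDING_USER_TURN_KEY


def close_dangling_turn(messages: list[dict[str, Any]], metadata: dict[str, Any]) -> bool:
    if not metadata.get(PENDING_KEY):
        return False  # the previous turn finished normally
    if messages and messages[-1].get("role") == "user":
        # The process died after the user message was saved but before any reply.
        messages.append(
            {
                "role": "assistant",
                "content": "Error: Task interrupted before a response was generated.",
                "timestamp": datetime.now().isoformat(),
            }
        )
    metadata.pop(PENDING_KEY, None)
    return True


messages = [{"role": "user", "content": "summarize the repo"}]
metadata = {PENDING_KEY: True}  # assumed state left behind by _mark_pending_user_turn
print(close_dangling_turn(messages, metadata))  # True
print(messages[-1]["content"])
print(close_dangling_turn(messages, metadata))  # False: the flag was already cleared
```

Closing the dangling turn with an explicit error message keeps user and assistant roles alternating, which many chat APIs expect when the history is replayed.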
    async def process_direct(
        self,
        content: str,
        session_key: str = "cli:direct",
        channel: str = "cli",
        chat_id: str = "direct",
        media: list[str] | None = None,
        on_progress: Callable[..., Awaitable[None]] | None = None,
        on_stream: Callable[[str], Awaitable[None]] | None = None,
        on_stream_end: Callable[..., Awaitable[None]] | None = None,
    ) -> OutboundMessage | None:
        """Process a message directly and return the outbound payload."""
        await self._connect_mcp()
        msg = InboundMessage(
            channel=channel, sender_id="user", chat_id=chat_id,
            content=content, media=media or [],
        )
        return await self._process_message(
            msg,
            session_key=session_key,
            on_progress=on_progress,
            on_stream=on_stream,
            on_stream_end=on_stream_end,
        )