From 3a420136bbce81f80065b3edfc2dcc8f5ef0fdcb Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Fri, 29 May 2026 03:42:53 +0800 Subject: [PATCH] feat(webui): add project workspaces and access controls (#4007) * feat(webui): add project workspaces and access controls * feat(webui): add project workspaces and access controls * refactor(tools): centralize workspace access resolution * refactor(webui): remove unused workspace host state * fix(webui): hide estimated file edit label * fix(webui): clarify file edit deletion feedback * fix(webui): label deleted file activity * fix(webui): flatten file edit activity rows * fix(core): remove path-only patch deletion * fix(core): keep apply patch non-destructive * refactor(webui): trim workspace host plumbing * fix(tools): register exec with tools config --- .gitignore | 2 + nanobot/agent/context.py | 33 +- nanobot/agent/loop.py | 71 +- nanobot/agent/subagent.py | 73 +- nanobot/agent/tools/apply_patch.py | 76 +- nanobot/agent/tools/cli_apps.py | 10 +- nanobot/agent/tools/context.py | 24 + nanobot/agent/tools/exec_session.py | 25 +- nanobot/agent/tools/filesystem.py | 26 +- nanobot/agent/tools/image_generation.py | 32 +- nanobot/agent/tools/message.py | 13 +- nanobot/agent/tools/path_utils.py | 34 +- nanobot/agent/tools/runtime_state.py | 3 + nanobot/agent/tools/search.py | 5 +- nanobot/agent/tools/self.py | 21 +- nanobot/agent/tools/shell.py | 58 +- nanobot/agent/tools/spawn.py | 2 + nanobot/apps/cli/service.py | 16 +- nanobot/channels/manager.py | 11 +- nanobot/channels/matrix.py | 8 +- nanobot/channels/websocket.py | 341 +++++++- nanobot/cli/commands.py | 139 +++- nanobot/config/schema.py | 16 +- nanobot/providers/openai_codex_provider.py | 43 +- .../providers/openai_responses/__init__.py | 2 + nanobot/providers/openai_responses/parsing.py | 107 ++- nanobot/security/network.py | 30 +- nanobot/security/workspace_access.py | 430 ++++++++++ nanobot/security/workspace_policy.py | 85 ++ nanobot/utils/file_edit_events.py | 13 +- nanobot/webui/mcp_presets_api.py | 8 +- nanobot/webui/settings_api.py | 356 +++++++- nanobot/webui/sidebar_state.py | 5 +- nanobot/webui/transcript.py | 185 ++++- nanobot/webui/workspaces.py | 283 +++++++ .../test_loop_direct_websocket_status.py | 55 ++ tests/agent/test_workspace_scope.py | 344 ++++++++ tests/channels/test_websocket_channel.py | 468 ++++++++++- tests/channels/test_websocket_http_routes.py | 12 + tests/cli/test_commands.py | 29 + tests/cli_apps/test_service.py | 2 + tests/config/test_config_migration.py | 21 + tests/providers/test_openai_codex_provider.py | 61 +- tests/providers/test_openai_responses.py | 204 ++++- tests/security/test_security_network.py | 17 +- tests/security/test_workspace_policy.py | 69 ++ tests/security/test_workspace_sandbox.py | 68 ++ tests/test_tool_contextvars.py | 3 + tests/tools/test_apply_patch_tool.py | 52 +- tests/tools/test_exec_security.py | 65 ++ tests/tools/test_tool_loader.py | 27 +- tests/tools/test_web_fetch_security.py | 19 + tests/utils/test_file_edit_events.py | 6 - tests/utils/test_webui_sidebar_state.py | 4 + tests/utils/test_webui_transcript.py | 97 +++ tests/utils/test_webui_workspaces.py | 154 ++++ tests/webui/test_settings_api.py | 248 +++++- webui/src/App.tsx | 566 ++++++++++--- webui/src/components/ChatList.tsx | 664 ++++++++------- webui/src/components/RenameChatDialog.tsx | 12 +- webui/src/components/Sidebar.tsx | 140 +--- .../src/components/settings/SettingsView.tsx | 770 +++++++++++++++--- .../thread/AgentActivityCluster.tsx | 163 +++- .../components/thread/StreamErrorNotice.tsx | 7 +- .../src/components/thread/ThreadComposer.tsx | 279 ++++--- webui/src/components/thread/ThreadHeader.tsx | 8 +- .../src/components/thread/ThreadMessages.tsx | 61 +- webui/src/components/thread/ThreadShell.tsx | 189 ++--- .../src/components/thread/ThreadViewport.tsx | 8 +- .../components/thread/WorkspaceControls.tsx | 326 ++++++++ webui/src/components/ui/alert-dialog.tsx | 4 - webui/src/components/ui/dialog.tsx | 5 - webui/src/components/ui/dropdown-menu.tsx | 34 +- webui/src/components/ui/sheet.tsx | 12 +- webui/src/globals.css | 8 + webui/src/hooks/useNanobotStream.ts | 176 +++- webui/src/hooks/useSessions.ts | 9 +- webui/src/hooks/useSidebarState.ts | 2 + webui/src/i18n/config.ts | 2 +- webui/src/i18n/locales/en/common.json | 132 ++- webui/src/i18n/locales/es/common.json | 134 ++- webui/src/i18n/locales/fr/common.json | 134 ++- webui/src/i18n/locales/id/common.json | 134 ++- webui/src/i18n/locales/ja/common.json | 134 ++- webui/src/i18n/locales/ko/common.json | 134 ++- webui/src/i18n/locales/vi/common.json | 134 ++- webui/src/i18n/locales/zh-CN/common.json | 130 ++- webui/src/i18n/locales/zh-TW/common.json | 134 ++- webui/src/lib/api.ts | 80 ++ webui/src/lib/bootstrap.ts | 18 +- webui/src/lib/chat-groups.ts | 372 +++++++++ webui/src/lib/nanobot-client.ts | 68 +- webui/src/lib/provider-brand.ts | 10 +- webui/src/lib/runtime.ts | 211 +++++ webui/src/lib/types.ts | 105 ++- webui/src/lib/workspace.ts | 56 ++ .../src/tests/agent-activity-cluster.test.tsx | 98 +++ webui/src/tests/api.test.ts | 106 +++ webui/src/tests/app-layout.test.tsx | 57 +- webui/src/tests/bootstrap.test.ts | 23 + webui/src/tests/chat-list.test.tsx | 257 ++++++ webui/src/tests/i18n.test.tsx | 21 +- webui/src/tests/nanobot-client.test.ts | 116 ++- webui/src/tests/provider-brand.test.ts | 6 + webui/src/tests/settings-view.test.tsx | 109 ++- webui/src/tests/thread-composer.test.tsx | 254 +++++- webui/src/tests/thread-messages.test.tsx | 146 ++++ webui/src/tests/thread-shell.test.tsx | 231 +++++- webui/src/tests/useNanobotStream.test.tsx | 146 +++- webui/src/tests/useSessions.test.tsx | 25 + webui/vite.config.ts | 23 +- 111 files changed, 9972 insertions(+), 1822 deletions(-) create mode 100644 nanobot/security/workspace_access.py create mode 100644 nanobot/security/workspace_policy.py create mode 100644 nanobot/webui/workspaces.py create mode 100644 tests/agent/test_loop_direct_websocket_status.py create mode 100644 tests/agent/test_workspace_scope.py create mode 100644 tests/security/test_workspace_policy.py create mode 100644 tests/security/test_workspace_sandbox.py create mode 100644 tests/utils/test_webui_workspaces.py create mode 100644 webui/src/components/thread/WorkspaceControls.tsx create mode 100644 webui/src/lib/chat-groups.ts create mode 100644 webui/src/lib/runtime.ts create mode 100644 webui/src/lib/workspace.ts create mode 100644 webui/src/tests/bootstrap.test.ts create mode 100644 webui/src/tests/chat-list.test.tsx diff --git a/.gitignore b/.gitignore index 19b129a26..cddec5083 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ .env .web .orion +nanobot-desktop/ +desktop/ # Claude / AI assistant artifacts docs/superpowers/ diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index 68ac5f324..9aa7395c3 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -68,11 +68,13 @@ class ContextBuilder: skill_names: list[str] | None = None, channel: str | None = None, session_summary: str | None = None, + workspace: Path | None = None, ) -> str: """Build the system prompt from identity, bootstrap files, memory, and skills.""" - parts = [self._get_identity(channel=channel)] + root = workspace or self.workspace + parts = [self._get_identity(channel=channel, workspace=root)] - bootstrap = self._load_bootstrap_files() + bootstrap = self._load_bootstrap_files(root) if bootstrap: parts.append(bootstrap) @@ -106,9 +108,10 @@ class ContextBuilder: return "\n\n---\n\n".join(parts) - def _get_identity(self, channel: str | None = None) -> str: + def _get_identity(self, channel: str | None = None, workspace: Path | None = None) -> str: """Get the core identity section.""" - workspace_path = str(self.workspace.expanduser().resolve()) + root = workspace or self.workspace + workspace_path = str(root.expanduser().resolve()) system = platform.system() runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}" @@ -152,12 +155,13 @@ class ContextBuilder: return _to_blocks(left) + _to_blocks(right) - def _load_bootstrap_files(self) -> str: + def _load_bootstrap_files(self, workspace: Path | None = None) -> str: """Load all bootstrap files from workspace.""" parts = [] + root = workspace or self.workspace for filename in self.BOOTSTRAP_FILES: - file_path = self.workspace / filename + file_path = root / filename if file_path.exists(): content = file_path.read_text(encoding="utf-8") parts.append(f"## {filename}\n\n{content}") @@ -185,11 +189,18 @@ class ContextBuilder: session_summary: str | None = None, session_metadata: Mapping[str, Any] | None = None, current_runtime_lines: Sequence[str] | None = None, + workspace: Path | None = None, + runtime_state: Any | None = None, + inbound_message: Any | None = None, + skip_runtime_lines: bool = False, ) -> list[dict[str, Any]]: """Build the complete message list for an LLM call.""" + root = workspace or self.workspace extra = [ *goal_state_runtime_lines(session_metadata), ] + if runtime_state is not None and inbound_message is not None: + extra.extend(runtime_lines(runtime_state, inbound_message, root, skip=skip_runtime_lines)) if current_runtime_lines: extra.extend(line for line in current_runtime_lines if line) runtime_ctx = self._build_runtime_context( @@ -210,7 +221,15 @@ class ContextBuilder: else: merged = user_content + [{"type": "text", "text": runtime_ctx}] messages = [ - {"role": "system", "content": self.build_system_prompt(skill_names, channel=channel, session_summary=session_summary)}, + { + "role": "system", + "content": self.build_system_prompt( + skill_names, + channel=channel, + session_summary=session_summary, + workspace=root, + ), + }, *history, ] if messages[-1].get("role") == current_role: diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 5a0985cbd..2be846ceb 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -25,8 +25,14 @@ from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRun from nanobot.agent.subagent import SubagentManager from nanobot.agent.tools.file_state import FileStateStore, bind_file_states, reset_file_states from nanobot.agent.tools.message import MessageTool +from nanobot.agent.tools.context import RequestContext, bind_request_context, reset_request_context from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.self import MyTool +from nanobot.security.workspace_access import ( + WorkspaceScopeResolver, + bind_workspace_scope, + reset_workspace_scope, +) from nanobot.bus.events import InboundMessage, OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.command import CommandContext, CommandRouter, register_builtin_commands @@ -114,7 +120,6 @@ class TurnContext: pending_queue: asyncio.Queue | None = None pending_summary: str | None = None - turn_wall_started_at: float = field(default_factory=time.time) turn_latency_ms: int | None = None @@ -241,6 +246,10 @@ class AgentLoop: self._image_generation_provider_configs["openrouter"] = image_generation_provider_config self.cron_service = cron_service self.restrict_to_workspace = restrict_to_workspace + self.workspace_scopes = WorkspaceScopeResolver( + default_workspace=workspace, + default_restrict_to_workspace=restrict_to_workspace, + ) self._start_time = time.time() self._last_usage: dict[str, int] = {} self._pending_turn_latency_ms: dict[str, int] = {} @@ -470,6 +479,7 @@ class AgentLoop: provider_snapshot_loader=self._provider_snapshot_loader, image_generation_provider_configs=self._image_generation_provider_configs, timezone=self.context.timezone or "UTC", + workspace_sandbox=self.workspace_scopes.sandbox_status, ) loader = ToolLoader() registered = loader.load(ctx, self.tools) @@ -493,7 +503,7 @@ class AgentLoop: session_key: str | None = None, ) -> None: """Update context for all tools that need routing info.""" - from nanobot.agent.tools.context import ContextAware, RequestContext + from nanobot.agent.tools.context import ContextAware if session_key is not None: effective_key = session_key @@ -575,6 +585,7 @@ class AgentLoop: pending_summary: str | None, ) -> list[dict[str, Any]]: """Build the initial message list for the LLM turn.""" + scope = self.workspace_scopes.for_message(msg, session.metadata) return self.context.build_messages( history=history, current_message=image_generation_prompt(msg.content, msg.metadata), @@ -583,7 +594,10 @@ class AgentLoop: chat_id=self._runtime_chat_id(msg), sender_id=msg.sender_id, session_summary=pending_summary, - session_metadata=session.metadata, current_runtime_lines=agent_context.runtime_lines(self, msg, self.context.workspace), + session_metadata=session.metadata, + workspace=scope.project_path, + runtime_state=self, + inbound_message=msg, ) async def _dispatch_command_inline( @@ -733,7 +747,21 @@ class AgentLoop: return items active_session_key = session.key if session else session_key + effective_scope = self.workspace_scopes.for_turn( + channel=channel, + message_metadata=metadata, + session_metadata=session.metadata if session is not None else None, + ) + request_ctx = RequestContext( + channel=channel, + chat_id=chat_id, + message_id=message_id, + session_key=active_session_key, + metadata=dict(metadata or {}), + ) file_state_token = bind_file_states(self._file_state_store.for_session(active_session_key)) + request_token = bind_request_context(request_ctx) + workspace_token = bind_workspace_scope(effective_scope) # Build continuation message that embeds the active goal objective so # the LLM can see it even if earlier Runtime Context was truncated. _goal_lines = goal_state_runtime_lines(session.metadata if session is not None else None) @@ -753,7 +781,7 @@ class AgentLoop: hook=hook, error_message="Sorry, I encountered an error calling the AI model.", concurrent_tools=True, - workspace=self.workspace, + workspace=effective_scope.project_path, session_key=session.key if session else None, context_window_tokens=self.context_window_tokens, context_block_limit=self.context_block_limit, @@ -774,6 +802,8 @@ class AgentLoop: goal_continue_message=_goal_continue, )) finally: + reset_workspace_scope(workspace_token) + reset_request_context(request_token) reset_file_states(file_state_token) self._last_usage = result.usage if result.stop_reason == "max_iterations": @@ -1063,6 +1093,7 @@ class AgentLoop: } history = session.get_history(**_hist_kwargs) current_role = "assistant" if is_subagent else "user" + workspace_scope = self.workspace_scopes.for_message(msg, session.metadata) messages = self.context.build_messages( history=history, @@ -1072,7 +1103,11 @@ class AgentLoop: current_role=current_role, sender_id=msg.sender_id, session_summary=pending, - session_metadata=session.metadata, current_runtime_lines=agent_context.runtime_lines(self, msg, self.context.workspace, skip=is_subagent), + session_metadata=session.metadata, + workspace=workspace_scope.project_path, + runtime_state=self, + inbound_message=msg, + skip_runtime_lines=is_subagent, ) t_wall = time.time() final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop( @@ -1248,6 +1283,7 @@ class AgentLoop: if ctx.session is None: ctx.session = self.sessions.get_or_create(ctx.session_key) mark_webui_session(ctx.session, msg.metadata) + self.workspace_scopes.persist_message_scope(ctx.session, msg) if self._restore_runtime_checkpoint(ctx.session): self.sessions.save(ctx.session) @@ -1315,7 +1351,10 @@ class AgentLoop: ) ctx.initial_messages = self._build_initial_messages( - ctx.msg, ctx.session, ctx.history, ctx.pending_summary + ctx.msg, + ctx.session, + ctx.history, + ctx.pending_summary, ) ctx.user_persisted_early = self._persist_user_message_early( ctx.msg, ctx.session @@ -1618,10 +1657,16 @@ class AgentLoop: channel=channel, sender_id="user", chat_id=chat_id, content=content, media=media or [], ) - return await self._process_message( - msg, - session_key=session_key, - on_progress=on_progress, - on_stream=on_stream, - on_stream_end=on_stream_end, - ) + try: + return await self._process_message( + msg, + session_key=session_key, + on_progress=on_progress, + on_stream=on_stream, + on_stream_end=on_stream_end, + ) + finally: + if channel == "websocket": + await self._webui_turns.publish_run_status(msg, "idle") + self._pending_turn_latency_ms.pop(session_key, None) + self._webui_turns.discard(session_key) diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index 6b39f66d0..8a752c6f7 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -16,6 +16,12 @@ from nanobot.agent.tools.context import ToolContext from nanobot.agent.tools.file_state import FileStates from nanobot.agent.tools.loader import ToolLoader from nanobot.agent.tools.registry import ToolRegistry +from nanobot.security.workspace_access import ( + WorkspaceScope, + bind_workspace_scope, + reset_workspace_scope, + workspace_sandbox_status, +) from nanobot.bus.events import InboundMessage from nanobot.bus.queue import MessageBus from nanobot.config.schema import AgentDefaults, ToolsConfig @@ -128,6 +134,10 @@ class SubagentManager: config=cfg, workspace=str(root.resolve()), file_state_store=FileStates(), + workspace_sandbox=workspace_sandbox_status( + restrict_to_workspace=cfg.restrict_to_workspace, + workspace=root, + ), ) ToolLoader().load(ctx, registry, scope="subagent") return registry @@ -146,6 +156,7 @@ class SubagentManager: session_key: str | None = None, origin_message_id: str | None = None, temperature: float | None = None, + workspace_scope: WorkspaceScope | None = None, ) -> str: """Spawn a subagent to execute a task in the background.""" task_id = str(uuid.uuid4())[:8] @@ -162,7 +173,14 @@ class SubagentManager: bg_task = asyncio.create_task( self._run_subagent( - task_id, task, display_label, origin, status, origin_message_id, temperature + task_id, + task, + display_label, + origin, + status, + origin_message_id, + temperature, + workspace_scope, ) ) self._running_tasks[task_id] = bg_task @@ -191,6 +209,7 @@ class SubagentManager: status: SubagentStatus, origin_message_id: str | None = None, temperature: float | None = None, + workspace_scope: WorkspaceScope | None = None, ) -> None: """Execute the subagent task and announce the result.""" logger.info("Subagent [{}] starting task: {}", task_id, label) @@ -200,8 +219,13 @@ class SubagentManager: status.iteration = payload.get("iteration", status.iteration) try: - tools = self._build_tools() - system_prompt = self._build_subagent_prompt() + root = workspace_scope.project_path if workspace_scope is not None else self.workspace + cfg = None + if workspace_scope is not None: + cfg = self._subagent_tools_config() + cfg.restrict_to_workspace = workspace_scope.restrict_to_workspace + tools = self._build_tools(workspace=root, tools_config=cfg) + system_prompt = self._build_subagent_prompt(workspace=root) messages: list[dict[str, Any]] = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": task}, @@ -213,21 +237,27 @@ class SubagentManager: if self._llm_wall_timeout_for_session else None ) - result = await self.runner.run(AgentRunSpec( - initial_messages=messages, - tools=tools, - model=self.model, - temperature=temperature, - max_iterations=self.max_iterations, - max_tool_result_chars=self.max_tool_result_chars, - hook=_SubagentHook(task_id, status), - max_iterations_message="Task completed but no final response was generated.", - error_message=None, - fail_on_tool_error=True, - checkpoint_callback=_on_checkpoint, - session_key=sess_key, - llm_timeout_s=llm_timeout, - )) + token = bind_workspace_scope(workspace_scope) if workspace_scope is not None else None + try: + result = await self.runner.run(AgentRunSpec( + initial_messages=messages, + tools=tools, + model=self.model, + temperature=temperature, + max_iterations=self.max_iterations, + max_tool_result_chars=self.max_tool_result_chars, + hook=_SubagentHook(task_id, status), + max_iterations_message="Task completed but no final response was generated.", + error_message=None, + fail_on_tool_error=True, + checkpoint_callback=_on_checkpoint, + session_key=sess_key, + workspace=root, + llm_timeout_s=llm_timeout, + )) + finally: + if token is not None: + reset_workspace_scope(token) status.phase = "done" status.stop_reason = result.stop_reason @@ -321,20 +351,21 @@ class SubagentManager: lines.append(f"- {result.error}") return "\n".join(lines) or (result.error or "Error: subagent execution failed.") - def _build_subagent_prompt(self) -> str: + def _build_subagent_prompt(self, workspace: Path | None = None) -> str: """Build a focused system prompt for the subagent.""" from nanobot.agent.context import ContextBuilder from nanobot.agent.skills import SkillsLoader time_ctx = ContextBuilder._build_runtime_context(None, None) + root = workspace or self.workspace skills_summary = SkillsLoader( - self.workspace, + root, disabled_skills=self.disabled_skills, ).build_skills_summary() return render_template( "agent/subagent_system.md", time_ctx=time_ctx, - workspace=str(self.workspace), + workspace=str(root), skills_summary=skills_summary or "", ) diff --git a/nanobot/agent/tools/apply_patch.py b/nanobot/agent/tools/apply_patch.py index ac524f7fc..a1acd4c90 100644 --- a/nanobot/agent/tools/apply_patch.py +++ b/nanobot/agent/tools/apply_patch.py @@ -88,11 +88,11 @@ def _format_summary(summary: _PatchSummary) -> str: items=ObjectSchema( path=StringSchema("Relative path to the file to edit."), action=StringSchema( - "Operation type: replace (find and replace text), add (append new content or create file), delete (remove text).", - enum=["replace", "add", "delete"], + "Operation type: replace or add.", + enum=["replace", "add"], ), old_text=StringSchema( - "Exact text to search for in the file. Required for replace and delete.", + "Exact text to search for in the file. Required for replace.", nullable=True, ), new_text=StringSchema( @@ -124,7 +124,8 @@ class ApplyPatchTool(_FsTool): def description(self) -> str: return ( "Default tool for code edits. Supports multi-file changes in a single call. " - "Provide a list of structured edits, each specifying a file path, action (replace/add/delete), and the text to change. " + "Provide a list of structured edits, each specifying a file path, action " + "(replace/add), and the exact text to change. " "Paths must be relative. Set dry_run=true to validate and preview without writing files. " "Use edit_file only for small exact replacements on a single file." ) @@ -140,7 +141,6 @@ class ApplyPatchTool(_FsTool): raise _PatchError("must provide edits") writes: dict[Path, str] = {} - deletes: set[Path] = set() summaries: list[_PatchSummary] = [] for edit in edits: @@ -183,7 +183,6 @@ class ApplyPatchTool(_FsTool): if uses_crlf: new_norm = new_norm.replace("\n", "\r\n") writes[source] = new_norm - deletes.discard(source) added, deleted = _line_diff_stats(content, new_norm) action_name = "update" else: @@ -191,7 +190,6 @@ class ApplyPatchTool(_FsTool): if new_norm and not new_norm.endswith("\n"): new_norm += "\n" writes[source] = new_norm - deletes.discard(source) added = _text_line_count(new_norm) deleted = 0 action_name = "add" @@ -246,7 +244,6 @@ class ApplyPatchTool(_FsTool): new_norm = new_norm.replace("\n", "\r\n") writes[source] = new_norm - deletes.discard(source) added, deleted = _line_diff_stats(content, new_norm) summaries.append( _PatchSummary( @@ -254,62 +251,6 @@ class ApplyPatchTool(_FsTool): ) ) - elif action == "delete": - old_text = edit.get("old_text") or "" - if not old_text: - raise _PatchError(f"old_text required for delete: {path}") - - pending = writes.get(source) - if pending is not None: - content = pending - elif source.exists(): - raw = source.read_bytes() - try: - content = raw.decode("utf-8") - except UnicodeDecodeError: - raise _PatchError(f"file is not UTF-8 text: {path}") - else: - raise _PatchError(f"file to update does not exist: {path}") - - if pending is None and not source.is_file(): - raise _PatchError(f"path to update is not a file: {path}") - - uses_crlf = "\r\n" in content - norm_content = content.replace("\r\n", "\n") - norm_old = old_text.replace("\r\n", "\n") - - pos = norm_content.find(norm_old) - if pos < 0: - raise _PatchError(f"old_text not found in {path}") - if norm_content.find(norm_old, pos + 1) >= 0: - raise _PatchError(f"old_text appears multiple times in {path}") - - if norm_old == norm_content: - deletes.add(source) - writes.pop(source, None) - added, deleted = 0, _text_line_count(content) - summaries.append( - _PatchSummary( - action="delete", path=path, added=added, deleted=deleted - ) - ) - else: - new_norm = ( - norm_content[:pos] + norm_content[pos + len(norm_old) :] - ) - if new_norm and not new_norm.endswith("\n"): - new_norm += "\n" - if uses_crlf: - new_norm = new_norm.replace("\n", "\r\n") - writes[source] = new_norm - deletes.discard(source) - added, deleted = _line_diff_stats(content, new_norm) - summaries.append( - _PatchSummary( - action="update", path=path, added=added, deleted=deleted - ) - ) - else: raise _PatchError(f"unknown action: {action}") @@ -319,13 +260,10 @@ class ApplyPatchTool(_FsTool): ) backups: dict[Path, bytes | None] = {} - for path in set(writes) | deletes: + for path in writes: backups[path] = path.read_bytes() if path.exists() else None try: - for path in deletes: - if path.exists(): - path.unlink() for path, content in writes.items(): path.parent.mkdir(parents=True, exist_ok=True) path.write_text(content, encoding="utf-8", newline="") @@ -339,7 +277,7 @@ class ApplyPatchTool(_FsTool): path.write_bytes(data) raise - for path in set(writes) | deletes: + for path in writes: self._file_states.record_write(path) return "Patch applied:\n" + "\n".join( _format_summary(summary) for summary in summaries diff --git a/nanobot/agent/tools/cli_apps.py b/nanobot/agent/tools/cli_apps.py index 3d0c109ab..9bee1a34a 100644 --- a/nanobot/agent/tools/cli_apps.py +++ b/nanobot/agent/tools/cli_apps.py @@ -9,6 +9,7 @@ from pydantic import Field from nanobot.agent.tools.base import Tool, tool_parameters from nanobot.agent.tools.schema import ArraySchema, BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema +from nanobot.security.workspace_access import current_tool_workspace from nanobot.apps.cli import CliAppError, CliAppManager, CliAppsRuntimeConfig from nanobot.config.schema import Base @@ -113,7 +114,12 @@ class CliAppsTool(Tool): working_dir: str | None = None, timeout: int | None = None, ) -> str: - manager = CliAppManager(workspace=self.workspace, runtime=self.runtime) + access = current_tool_workspace( + self.workspace, + restrict_to_workspace=self.restrict_to_workspace, + ) + workspace = access.project_path or self.workspace + manager = CliAppManager(workspace=workspace, runtime=self.runtime) try: return manager.run( name, @@ -121,7 +127,7 @@ class CliAppsTool(Tool): json_output=bool(json), working_dir=working_dir, timeout=timeout, - restrict_to_workspace=self.restrict_to_workspace, + restrict_to_workspace=access.restrict_to_workspace, ) except CliAppError as exc: return f"Error: {exc.message}" diff --git a/nanobot/agent/tools/context.py b/nanobot/agent/tools/context.py index bd9898a02..61aa8ed7c 100644 --- a/nanobot/agent/tools/context.py +++ b/nanobot/agent/tools/context.py @@ -1,9 +1,15 @@ """Runtime context for tool construction.""" from __future__ import annotations +from contextvars import ContextVar, Token from dataclasses import dataclass, field from typing import Any, Callable, Protocol, runtime_checkable +_CURRENT_REQUEST_CONTEXT: ContextVar["RequestContext | None"] = ContextVar( + "nanobot_tool_request_context", + default=None, +) + @dataclass(frozen=True) class RequestContext: @@ -21,6 +27,23 @@ class ContextAware(Protocol): ... +def bind_request_context(ctx: RequestContext) -> Token[RequestContext | None]: + return _CURRENT_REQUEST_CONTEXT.set(ctx) + + +def reset_request_context(token: Token[RequestContext | None]) -> None: + _CURRENT_REQUEST_CONTEXT.reset(token) + + +def current_request_context() -> RequestContext | None: + return _CURRENT_REQUEST_CONTEXT.get() + + +def current_request_session_key() -> str | None: + ctx = current_request_context() + return ctx.session_key if ctx else None + + @dataclass class ToolContext: config: Any @@ -33,3 +56,4 @@ class ToolContext: provider_snapshot_loader: Callable[[], Any] | None = None image_generation_provider_configs: dict[str, Any] | None = None timezone: str = "UTC" + workspace_sandbox: Any | None = None diff --git a/nanobot/agent/tools/exec_session.py b/nanobot/agent/tools/exec_session.py index c23d175e2..8b53f250f 100644 --- a/nanobot/agent/tools/exec_session.py +++ b/nanobot/agent/tools/exec_session.py @@ -10,6 +10,7 @@ from contextlib import suppress from dataclasses import dataclass from typing import Any +from nanobot.agent.tools.context import current_request_session_key from nanobot.agent.tools.base import Tool, tool_parameters from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema @@ -43,6 +44,7 @@ class ExecSessionInfo: idle_s: float remaining_s: float returncode: int | None + owner_session_key: str | None = None class _ExecSession: @@ -54,11 +56,13 @@ class _ExecSession: command: str, cwd: str, timeout: int | None, + owner_session_key: str | None = None, ) -> None: self.session_id = session_id self.process = process self.command = command self.cwd = cwd + self.owner_session_key = owner_session_key self.started_at = time.monotonic() # timeout None/0 means no limit; an infinite deadline is never reached. self.deadline = time.monotonic() + timeout if timeout else float("inf") @@ -175,6 +179,7 @@ class ExecSessionManager: login: bool, yield_time_ms: int, max_output_chars: int, + owner_session_key: str | None = None, ) -> tuple[str, _SessionPoll]: async with self._lock: await self._cleanup_locked() @@ -188,6 +193,7 @@ class ExecSessionManager: command=command, cwd=cwd, timeout=timeout, + owner_session_key=owner_session_key, ) self._sessions[session_id] = session @@ -206,12 +212,19 @@ class ExecSessionManager: terminate: bool, yield_time_ms: int, max_output_chars: int, + owner_session_key: str | None = None, ) -> _SessionPoll: async with self._lock: await self._cleanup_locked() session = self._sessions.get(session_id) if session is None: raise KeyError(session_id) + if ( + owner_session_key + and session.owner_session_key + and session.owner_session_key != owner_session_key + ): + raise KeyError(session_id) if chars: error = await session.write(chars) @@ -236,7 +249,7 @@ class ExecSessionManager: self._sessions.pop(session_id, None) return poll - async def list(self) -> list[ExecSessionInfo]: + async def list(self, *, owner_session_key: str | None = None) -> list[ExecSessionInfo]: async with self._lock: await self._cleanup_locked() now = time.monotonic() @@ -249,8 +262,12 @@ class ExecSessionManager: idle_s=max(0.0, now - session.last_access), remaining_s=max(0.0, session.deadline - now), returncode=session.process.returncode, + owner_session_key=session.owner_session_key, ) for session_id, session in sorted(self._sessions.items()) + if not owner_session_key + or not session.owner_session_key + or session.owner_session_key == owner_session_key ] async def _cleanup_locked(self) -> None: @@ -477,6 +494,7 @@ class WriteStdinTool(Tool): terminate=terminate, yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS), max_output_chars=output_limit, + owner_session_key=current_request_session_key(), ) return format_session_poll(session_id, poll) except KeyError: @@ -510,6 +528,7 @@ class WriteStdinTool(Tool): terminate=terminate if first else False, yield_time_ms=step_ms, max_output_chars=max_output_chars, + owner_session_key=current_request_session_key(), ) first = False if poll.output: @@ -573,7 +592,9 @@ class ListExecSessionsTool(Tool): async def execute(self, **kwargs: Any) -> str: try: - sessions = await self._manager.list() + sessions = await self._manager.list( + owner_session_key=current_request_session_key(), + ) if not sessions: return "No active exec sessions." lines = [] diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index fa63e5f66..6e439495a 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -10,6 +10,7 @@ from typing import Any from nanobot.agent.tools.base import Tool, tool_parameters from nanobot.agent.tools.file_state import FileStates, _hash_file, current_file_states from nanobot.agent.tools.path_utils import resolve_workspace_path +from nanobot.security.workspace_access import current_tool_workspace from nanobot.agent.tools.schema import ( BooleanSchema, IntegerSchema, @@ -28,10 +29,18 @@ class _FsTool(Tool): allowed_dir: Path | None = None, extra_allowed_dirs: list[Path] | None = None, file_states: FileStates | None = None, + restrict_to_workspace: bool | None = None, + sandbox_restricts_workspace: bool = False, ): self._workspace = workspace self._allowed_dir = allowed_dir self._extra_allowed_dirs = extra_allowed_dirs + self._restrict_to_workspace = ( + bool(restrict_to_workspace) + if restrict_to_workspace is not None + else allowed_dir is not None + ) + self._sandbox_restricts_workspace = sandbox_restricts_workspace # Explicit state is used by isolated runners like Dream/subagents. # Main AgentLoop tools leave this unset and resolve state from the # current async task, which keeps shared tool instances session-safe. @@ -46,13 +55,16 @@ class _FsTool(Tool): ctx.config.restrict_to_workspace or ctx.config.exec.sandbox ) + sandbox_restricts = bool(ctx.config.exec.sandbox) allowed_dir = Path(ctx.workspace) if restrict else None - extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None + extra_read = [BUILTIN_SKILLS_DIR] return cls( workspace=Path(ctx.workspace), allowed_dir=allowed_dir, extra_allowed_dirs=extra_read, file_states=ctx.file_state_store, + restrict_to_workspace=ctx.config.restrict_to_workspace, + sandbox_restricts_workspace=sandbox_restricts, ) @property @@ -62,13 +74,21 @@ class _FsTool(Tool): return current_file_states(self._fallback_file_states) def _resolve(self, path: str) -> Path: + access = current_tool_workspace( + self._workspace, + restrict_to_workspace=self._restrict_to_workspace, + sandbox_restricts_workspace=self._sandbox_restricts_workspace, + ) return resolve_workspace_path( path, - self._workspace, - self._allowed_dir, + access.project_path, + access.allowed_root, self._extra_allowed_dirs, ) + def _display_workspace(self) -> Path | None: + return current_tool_workspace(self._workspace).project_path + # --------------------------------------------------------------------------- # read_file diff --git a/nanobot/agent/tools/image_generation.py b/nanobot/agent/tools/image_generation.py index a194d0fee..4471f999e 100644 --- a/nanobot/agent/tools/image_generation.py +++ b/nanobot/agent/tools/image_generation.py @@ -14,6 +14,7 @@ from nanobot.agent.tools.schema import ( StringSchema, tool_parameters_schema, ) +from nanobot.security.workspace_access import current_tool_workspace from nanobot.config.paths import get_media_dir from nanobot.config.schema import Base from nanobot.providers.image_generation import ( @@ -21,6 +22,7 @@ from nanobot.providers.image_generation import ( ImageGenerationProvider, get_image_gen_provider, ) +from nanobot.security.workspace_policy import WorkspaceBoundaryError, resolve_allowed_path from nanobot.utils.artifacts import ( ArtifactError, generated_image_tool_result, @@ -131,18 +133,22 @@ class ImageGenerationTool(Tool): return cls(**kwargs) def _resolve_reference_image(self, value: str) -> str: - raw_path = Path(value).expanduser() - path = raw_path if raw_path.is_absolute() else self.workspace / raw_path + access = current_tool_workspace(self.workspace, restrict_to_workspace=True) + workspace = access.project_path or self.workspace try: - resolved = path.resolve(strict=True) - except OSError as exc: - raise ImageGenerationError(f"reference image not found: {value}") from exc - - allowed_roots = [self.workspace.resolve(), get_media_dir().resolve()] - if not any(_is_relative_to(resolved, root) for root in allowed_roots): + resolved = resolve_allowed_path( + value, + workspace=workspace, + allowed_root=access.allowed_root, + extra_allowed_roots=[get_media_dir()] if access.allowed_root is not None else None, + strict=True, + ) + except WorkspaceBoundaryError as exc: raise ImageGenerationError( "reference_images must be inside the workspace or nanobot media directory" - ) + ) from exc + except OSError as exc: + raise ImageGenerationError(f"reference image not found: {value}") from exc if not resolved.is_file(): raise ImageGenerationError(f"reference image is not a file: {value}") raw = resolved.read_bytes() @@ -201,11 +207,3 @@ class ImageGenerationTool(Tool): return generated_image_tool_result(artifacts) except (ArtifactError, ImageGenerationError, OSError) as exc: return f"Error: {exc}" - - -def _is_relative_to(path: Path, root: Path) -> bool: - try: - path.relative_to(root) - except ValueError: - return False - return True diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index 63b45c38f..de0fcb1c5 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -8,6 +8,7 @@ from nanobot.agent.tools.base import Tool, tool_parameters from nanobot.agent.tools.context import ContextAware, RequestContext from nanobot.agent.tools.path_utils import resolve_workspace_path from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema +from nanobot.security.workspace_access import current_tool_workspace from nanobot.bus.events import OutboundMessage from nanobot.config.paths import get_workspace_path @@ -149,15 +150,19 @@ class MessageTool(Tool, ContextAware): def _resolve_media(self, media: list[str]) -> list[str]: """Resolve local media attachments and enforce workspace restriction when enabled.""" resolved: list[str] = [] - allowed_dir = self._workspace if self._restrict_to_workspace else None + access = current_tool_workspace( + self._workspace, + restrict_to_workspace=self._restrict_to_workspace, + ) + workspace = access.project_path or self._workspace for p in media: if p.startswith(("http://", "https://")): resolved.append(p) - elif not self._restrict_to_workspace: + elif not access.restrict_to_workspace: path = Path(p).expanduser() - resolved.append(p if path.is_absolute() else str(self._workspace / path)) + resolved.append(p if path.is_absolute() else str(workspace / path)) else: - resolved.append(str(resolve_workspace_path(p, self._workspace, allowed_dir))) + resolved.append(str(resolve_workspace_path(p, workspace, access.allowed_root))) return resolved async def execute( diff --git a/nanobot/agent/tools/path_utils.py b/nanobot/agent/tools/path_utils.py index a98fa3729..5d618cd51 100644 --- a/nanobot/agent/tools/path_utils.py +++ b/nanobot/agent/tools/path_utils.py @@ -3,21 +3,15 @@ from pathlib import Path from nanobot.config.paths import get_media_dir - -WORKSPACE_BOUNDARY_NOTE = ( - " (this is a hard policy boundary, not a transient failure; " - "do not retry with shell tricks or alternative tools, and ask " - "the user how to proceed if the resource is genuinely required)" +from nanobot.security.workspace_policy import ( + is_path_within, + resolve_allowed_path, ) def is_under(path: Path, directory: Path) -> bool: """Return True when path resolves under directory.""" - try: - path.relative_to(directory.resolve()) - return True - except ValueError: - return False + return is_path_within(path, directory) def resolve_workspace_path( @@ -27,16 +21,10 @@ def resolve_workspace_path( extra_allowed_dirs: list[Path] | None = None, ) -> Path: """Resolve path against workspace and enforce allowed directory containment.""" - p = Path(path).expanduser() - if not p.is_absolute() and workspace: - p = workspace / p - resolved = p.resolve() - if allowed_dir: - media_path = get_media_dir().resolve() - all_dirs = [allowed_dir, media_path, *(extra_allowed_dirs or [])] - if not any(is_under(resolved, d) for d in all_dirs): - raise PermissionError( - f"Path {path} is outside allowed directory {allowed_dir}" - + WORKSPACE_BOUNDARY_NOTE - ) - return resolved + extra_roots = [get_media_dir(), *(extra_allowed_dirs or [])] if allowed_dir else None + return resolve_allowed_path( + path, + workspace=workspace, + allowed_root=allowed_dir, + extra_allowed_roots=extra_roots, + ) diff --git a/nanobot/agent/tools/runtime_state.py b/nanobot/agent/tools/runtime_state.py index b3c24ac46..449c7fe08 100644 --- a/nanobot/agent/tools/runtime_state.py +++ b/nanobot/agent/tools/runtime_state.py @@ -42,6 +42,9 @@ class RuntimeState(Protocol): @property def exec_config(self) -> Any: ... + @property + def workspace_sandbox(self) -> Any: ... + @property def subagents(self) -> Any: ... diff --git a/nanobot/agent/tools/search.py b/nanobot/agent/tools/search.py index 0febb122c..30fefbb94 100644 --- a/nanobot/agent/tools/search.py +++ b/nanobot/agent/tools/search.py @@ -101,9 +101,10 @@ class _SearchTool(_FsTool): _IGNORE_DIRS = set(ListDirTool._IGNORE_DIRS) def _display_path(self, target: Path, root: Path) -> str: - if self._workspace: + workspace = self._display_workspace() + if workspace: with suppress(ValueError): - return target.relative_to(self._workspace).as_posix() + return target.relative_to(workspace).as_posix() return target.relative_to(root).as_posix() def _iter_files(self, root: Path) -> Iterable[Path]: diff --git a/nanobot/agent/tools/self.py b/nanobot/agent/tools/self.py index 2712df0dc..f12f83b3d 100644 --- a/nanobot/agent/tools/self.py +++ b/nanobot/agent/tools/self.py @@ -3,16 +3,18 @@ from __future__ import annotations import time -from typing import Any +from typing import TYPE_CHECKING, Any from loguru import logger -from nanobot.agent.subagent import SubagentStatus from nanobot.agent.tools.base import Tool from nanobot.agent.tools.context import ContextAware, RequestContext from nanobot.agent.tools.runtime_state import RuntimeState from nanobot.config.schema import Base +if TYPE_CHECKING: + from nanobot.agent.subagent import SubagentStatus + class MyToolConfig(Base): """Self-inspection tool configuration.""" @@ -33,6 +35,12 @@ def _has_real_attr(obj: Any, key: str) -> bool: return False +def _is_subagent_status(value: Any) -> bool: + from nanobot.agent.subagent import SubagentStatus + + return isinstance(value, SubagentStatus) + + class MyTool(Tool, ContextAware): """Check and set the agent loop's runtime configuration.""" @@ -68,6 +76,7 @@ class MyTool(Tool, ContextAware): "_current_iteration", # updated by runner only "exec_config", # inspect allowed (e.g. check sandbox), modify blocked "web_config", # inspect allowed (e.g. check enable), modify blocked + "workspace_sandbox", # read-only view of workspace enforcement level }) _DENIED_ATTRS = frozenset({ @@ -214,7 +223,7 @@ class MyTool(Tool, ContextAware): # ------------------------------------------------------------------ @staticmethod - def _format_status(st: SubagentStatus, indent: str = " ") -> str: + def _format_status(st: "SubagentStatus", indent: str = " ") -> str: elapsed = time.monotonic() - st.started_at tool_summary = ", ".join( f"{e.get('name', '?')}({e.get('status', '?')})" for e in st.tool_events[-5:] @@ -232,14 +241,14 @@ class MyTool(Tool, ContextAware): @staticmethod def _format_value(val: Any, key: str = "") -> str: - if isinstance(val, SubagentStatus): + if _is_subagent_status(val): header = f"Subagent [{val.task_id}] '{val.label}'" detail = MyTool._format_status(val, " ") return f"{header}\n task: {val.task_description}\n{detail}" # SubagentManager: delegate to its _task_statuses dict if hasattr(val, "_task_statuses") and isinstance(val._task_statuses, dict): return MyTool._format_value(val._task_statuses, key) - if isinstance(val, dict) and val and isinstance(next(iter(val.values())), SubagentStatus): + if isinstance(val, dict) and val and _is_subagent_status(next(iter(val.values()))): prefix = f"{key}: " if key else "" lines = [f"{prefix}{len(val)} subagent(s):"] for tid, st in val.items(): @@ -349,7 +358,7 @@ class MyTool(Tool, ContextAware): parts.append(self._format_value(getattr(state, k, None), k)) parts.append(self._format_value(state.model_preset, "model_preset")) # Other useful top-level keys shown in description - for k in ("workspace", "provider_retry_mode", "max_tool_result_chars", "_current_iteration", "web_config", "exec_config", "subagents"): + for k in ("workspace", "provider_retry_mode", "max_tool_result_chars", "_current_iteration", "web_config", "exec_config", "workspace_sandbox", "subagents"): if _has_real_attr(state, k): parts.append(self._format_value(getattr(state, k, None), k)) # Token usage diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index 88c454bfa..082f8cce4 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -25,10 +25,13 @@ from nanobot.agent.tools.exec_session import ( clamp_session_int, format_session_poll, ) +from nanobot.agent.tools.context import current_request_session_key from nanobot.agent.tools.sandbox import wrap_command from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema +from nanobot.security.workspace_access import current_scope_allows_loopback, current_tool_workspace from nanobot.config.paths import get_media_dir from nanobot.config.schema import Base +from nanobot.security.workspace_policy import is_path_within _IS_WINDOWS = sys.platform == "win32" @@ -140,6 +143,7 @@ class ExecTool(Tool): working_dir=ctx.workspace, timeout=cfg.timeout, restrict_to_workspace=ctx.config.restrict_to_workspace, + webui_allow_local_service_access=ctx.config.webui_allow_local_service_access, sandbox=cfg.sandbox, path_append=cfg.path_append, allowed_env_keys=cfg.allowed_env_keys, @@ -154,6 +158,8 @@ class ExecTool(Tool): deny_patterns: list[str] | None = None, allow_patterns: list[str] | None = None, restrict_to_workspace: bool = False, + webui_allow_local_service_access: bool = True, + allow_local_preview_access: bool | None = None, sandbox: str = "", path_append: str = "", allowed_env_keys: list[str] | None = None, @@ -183,6 +189,9 @@ class ExecTool(Tool): ] self.allow_patterns = allow_patterns or [] self.restrict_to_workspace = restrict_to_workspace + if allow_local_preview_access is not None: + webui_allow_local_service_access = allow_local_preview_access + self.webui_allow_local_service_access = webui_allow_local_service_access self.path_append = path_append self.allowed_env_keys = allowed_env_keys or [] self._session_manager = session_manager or DEFAULT_EXEC_SESSION_MANAGER @@ -313,6 +322,7 @@ class ExecTool(Tool): shell_program=prepared.shell_program, login=prepared.login, yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS), + owner_session_key=current_request_session_key(), max_output_chars=clamp_session_int( max_output_chars, DEFAULT_MAX_OUTPUT_CHARS, @@ -346,29 +356,39 @@ class ExecTool(Tool): shell: str | None = None, login: bool | None = None, ) -> _PreparedCommand | str: - cwd = working_dir or self.working_dir or os.getcwd() + access = current_tool_workspace( + self.working_dir, + restrict_to_workspace=self.restrict_to_workspace, + sandbox_restricts_workspace=bool(self.sandbox), + ) + workspace_root = str(access.project_path) if access.project_path is not None else self.working_dir + cwd = working_dir or workspace_root or os.getcwd() # Prevent an LLM-supplied working_dir from escaping the configured # workspace when restrict_to_workspace is enabled (#2826). Without # this, a caller can pass working_dir="/etc" and then all absolute # paths under /etc would pass the _guard_command check that anchors # on cwd. - if self.restrict_to_workspace and self.working_dir: + if access.restrict_to_workspace and workspace_root: try: requested = Path(cwd).expanduser().resolve() - workspace_root = Path(self.working_dir).expanduser().resolve() + resolved_root = Path(workspace_root).expanduser().resolve() except Exception: return ( "Error: working_dir could not be resolved" + _WORKSPACE_BOUNDARY_NOTE ) - if requested != workspace_root and workspace_root not in requested.parents: + if not is_path_within(requested, resolved_root): return ( "Error: working_dir is outside the configured workspace" + _WORKSPACE_BOUNDARY_NOTE ) - guard_error = self._guard_command(command, cwd) + guard_error = self._guard_command( + command, + cwd, + restrict_to_workspace=access.restrict_to_workspace, + ) if guard_error: return guard_error @@ -379,7 +399,7 @@ class ExecTool(Tool): self.sandbox, ) else: - workspace = self.working_dir or cwd + workspace = workspace_root or cwd command = wrap_command(self.sandbox, command, workspace, cwd) cwd = str(Path(workspace).resolve()) @@ -528,7 +548,13 @@ class ExecTool(Tool): env[key] = val return env - def _guard_command(self, command: str, cwd: str) -> str | None: + def _guard_command( + self, + command: str, + cwd: str, + *, + restrict_to_workspace: bool | None = None, + ) -> str | None: """Best-effort safety guard for potentially destructive commands.""" cmd = command.strip() lower = cmd.lower() @@ -548,11 +574,17 @@ class ExecTool(Tool): return "Error: Command blocked by allowlist filter (not in allowlist)" from nanobot.security.network import contains_internal_url - if contains_internal_url(cmd): + if contains_internal_url( + cmd, + allow_loopback=current_scope_allows_loopback( + enabled=self.webui_allow_local_service_access, + ), + ): # The runner turns this marker into a non-retryable security hint. return "Error: Command blocked by safety guard (internal/private URL detected)" - if self.restrict_to_workspace: + should_restrict = self.restrict_to_workspace if restrict_to_workspace is None else restrict_to_workspace + if should_restrict: if "..\\" in cmd or "../" in cmd: return ( "Error: Command blocked by safety guard (path traversal detected)" @@ -577,11 +609,9 @@ class ExecTool(Tool): continue media_path = get_media_dir().resolve() - if (p.is_absolute() - and cwd_path not in p.parents - and p != cwd_path - and media_path not in p.parents - and p != media_path + if p.is_absolute() and not ( + is_path_within(p, cwd_path) + or is_path_within(p, media_path) ): return ( "Error: Command blocked by safety guard (path outside working dir)" diff --git a/nanobot/agent/tools/spawn.py b/nanobot/agent/tools/spawn.py index dd0f8c43e..420afc048 100644 --- a/nanobot/agent/tools/spawn.py +++ b/nanobot/agent/tools/spawn.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any from nanobot.agent.tools.base import Tool, tool_parameters from nanobot.agent.tools.context import ContextAware, RequestContext from nanobot.agent.tools.schema import NumberSchema, StringSchema, tool_parameters_schema +from nanobot.security.workspace_access import current_workspace_scope if TYPE_CHECKING: from nanobot.agent.subagent import SubagentManager @@ -91,4 +92,5 @@ class SpawnTool(Tool, ContextAware): session_key=self._session_key.get(), origin_message_id=self._origin_message_id.get(), temperature=temperature, + workspace_scope=current_workspace_scope(), ) diff --git a/nanobot/apps/cli/service.py b/nanobot/apps/cli/service.py index dfe7277c9..5dedfa78c 100644 --- a/nanobot/apps/cli/service.py +++ b/nanobot/apps/cli/service.py @@ -20,6 +20,7 @@ import httpx from nanobot.apps.protocol import app_manifest, compact_dict from nanobot.config.paths import get_runtime_subdir +from nanobot.security.workspace_policy import is_path_within CLI_ANYTHING_REGISTRY_URL = "https://hkuds.github.io/CLI-Anything/registry.json" CLI_ANYTHING_PUBLIC_REGISTRY_URL = "https://hkuds.github.io/CLI-Anything/public_registry.json" @@ -32,6 +33,7 @@ _MAX_ARTIFACT_REPORT = 12 _SAFE_NAME_RE = re.compile(r"[^a-z0-9_-]+") _MENTION_RE = re.compile(r"(^|[\s([{])@([a-z0-9_-]+)\b", re.IGNORECASE) _SHELL_META_CHARS = ("|", "&&", "||", ";", "$(", "`", ">", "<") +_ENDORSEMENT_WORD_RE = re.compile(r"\bofficial\s+", re.IGNORECASE) _ARTIFACT_EXTENSIONS = frozenset({ ".csv", ".drawio", @@ -362,6 +364,12 @@ def _truncate(text: str, limit: int = _MAX_TOOL_OUTPUT_CHARS) -> str: return text[:limit] + f"\n\n... truncated {omitted} characters ..." +def _catalog_description(app: dict[str, Any]) -> str: + """Return catalog copy without implying vendor endorsement.""" + description = str(app.get("description") or "") + return _ENDORSEMENT_WORD_RE.sub("", description).strip() + + class CliAppManager: """Manage CLI-Anything registry entries and local install state.""" @@ -554,7 +562,7 @@ class CliAppManager: "name": name, "display_name": app.get("display_name") or name, "category": app.get("category") or "uncategorized", - "description": app.get("description") or "", + "description": _catalog_description(app), "requires": app.get("requires") or "", "source": app.get("_source") or "harness", "entry_point": entry_point, @@ -630,7 +638,7 @@ class CliAppManager: app_id=name, display_name=str(app.get("display_name") or name), version=str(app.get("version") or ""), - description=str(app.get("description") or ""), + description=_catalog_description(app), category=str(app.get("category") or "uncategorized"), source=f"cli-anything:{app.get('_source') or 'harness'}", logo_url=logo_url, @@ -802,7 +810,7 @@ class CliAppManager: name = str(app.get("name") or "unknown") display = str(app.get("display_name") or name) entry = str(app.get("entry_point") or f"cli-anything-{name}") - description = str(app.get("description") or f"Use {display} from nanobot.") + description = _catalog_description(app) or f"Use {display} from nanobot." return f"""--- name: {_safe_skill_name(name)} description: >- @@ -1018,7 +1026,7 @@ Use the `run_cli_app` tool with `name="{name}"` for command execution. Do not in cwd = Path(working_dir).expanduser() if working_dir else self.workspace cwd = cwd.resolve(strict=False) workspace = self.workspace.resolve(strict=False) - if restrict_to_workspace and cwd != workspace and not cwd.is_relative_to(workspace): + if restrict_to_workspace and not is_path_within(cwd, workspace): raise CliAppError("working_dir is outside the configured workspace") return cwd diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py index a1ac15495..2ccb31089 100644 --- a/nanobot/channels/manager.py +++ b/nanobot/channels/manager.py @@ -57,11 +57,17 @@ class ChannelManager: *, session_manager: "SessionManager | None" = None, webui_runtime_model_name: Callable[[], str | None] | None = None, + webui_static_dist: bool = True, + webui_runtime_surface: str = "browser", + webui_runtime_capabilities: dict[str, Any] | None = None, ): self.config = config self.bus = bus self._session_manager = session_manager self._webui_runtime_model_name = webui_runtime_model_name + self._webui_static_dist = webui_static_dist + self._webui_runtime_surface = webui_runtime_surface + self._webui_runtime_capabilities = dict(webui_runtime_capabilities or {}) self.channels: dict[str, BaseChannel] = {} self._dispatch_task: asyncio.Task | None = None self._origin_reply_fingerprints: dict[tuple[str, str, str], str] = {} @@ -107,12 +113,15 @@ class ChannelManager: if cls.name == "websocket": if self._session_manager is not None: kwargs["session_manager"] = self._session_manager - static_path = _default_webui_dist() + static_path = _default_webui_dist() if self._webui_static_dist else None if static_path is not None: kwargs["static_dist_path"] = static_path kwargs["workspace_path"] = self.config.workspace_path + kwargs["restrict_to_workspace"] = self.config.tools.restrict_to_workspace if self._webui_runtime_model_name is not None: kwargs["runtime_model_name"] = self._webui_runtime_model_name + kwargs["runtime_surface"] = self._webui_runtime_surface + kwargs["runtime_capabilities_overrides"] = self._webui_runtime_capabilities channel = cls(section, self.bus, **kwargs) channel.transcription_provider = transcription_provider channel.transcription_api_key = transcription_key diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py index a11be1e1c..41f7c8e4e 100644 --- a/nanobot/channels/matrix.py +++ b/nanobot/channels/matrix.py @@ -11,6 +11,8 @@ from typing import Any, Literal, TypeAlias from pydantic import Field +from nanobot.security.workspace_policy import is_path_within + try: import nh3 from mistune import create_markdown @@ -344,11 +346,7 @@ class MatrixChannel(BaseChannel): """Check path is inside workspace (when restriction enabled).""" if not self._restrict_to_workspace or not self._workspace: return True - try: - path.resolve(strict=False).relative_to(self._workspace) - return True - except ValueError: - return False + return is_path_within(path, self._workspace) def _collect_outbound_media_candidates(self, media: list[str]) -> list[Path]: """Deduplicate and resolve outbound attachment paths.""" diff --git a/nanobot/channels/websocket.py b/nanobot/channels/websocket.py index 87b5fbd3a..dc23b93f6 100644 --- a/nanobot/channels/websocket.py +++ b/nanobot/channels/websocket.py @@ -18,19 +18,24 @@ import ssl import time import uuid from collections.abc import Callable +from contextlib import suppress from pathlib import Path from typing import TYPE_CHECKING, Any, Self from urllib.parse import parse_qs, unquote, urlparse from loguru import logger from pydantic import Field, field_validator, model_validator -from websockets.asyncio.server import ServerConnection, serve +from websockets.asyncio.server import ServerConnection, serve, unix_serve from websockets.datastructures import Headers from websockets.exceptions import ConnectionClosed from websockets.http11 import Request as WsRequest from websockets.http11 import Response from nanobot.agent.tools.mcp import request_mcp_reload +from nanobot.security.workspace_access import ( + WORKSPACE_SCOPE_METADATA_KEY, + WorkspaceScopeError, +) from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel @@ -48,9 +53,15 @@ from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_c from nanobot.webui.settings_api import ( WebUISettingsError, create_model_configuration, + decorate_settings_payload, + login_oauth_provider, + logout_oauth_provider, + runtime_capabilities, settings_payload, update_agent_settings, update_image_generation_settings, + update_model_configuration, + update_network_safety_settings, update_provider_settings, update_web_search_settings, ) @@ -73,6 +84,9 @@ from nanobot.webui.transcript import ( build_webui_thread_response, rewrite_local_markdown_images, ) +from nanobot.webui.workspaces import ( + WebUIWorkspaceController, +) _MCP_PRESET_ACTIONS_BY_PATH = { "/api/settings/mcp-presets/enable": "enable", @@ -100,6 +114,41 @@ def _normalize_config_path(path: str) -> str: return _strip_trailing_slash(path) +def _case_insensitive_header(headers: Any, key: str) -> str: + """Read a header from websockets/http test stubs without assuming casing.""" + try: + value = headers.get(key) + except Exception: + value = None + if value is None: + try: + value = headers.get(key.lower()) + except Exception: + value = None + return str(value or "").strip() + + +def _safe_host_header(value: str) -> str: + """Return a safe Host header value, or empty when it should not be echoed.""" + value = value.strip() + if not value: + return "" + if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value): + return value + if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value): + return value + return "" + + +def _host_for_url(host: str, port: int) -> str: + host = host.strip() + if host in ("0.0.0.0", "::"): + host = "127.0.0.1" + if ":" in host and not host.startswith("["): + host = f"[{host}]" + return f"{host}:{port}" + + class WebSocketConfig(Base): """WebSocket server channel configuration. @@ -123,6 +172,7 @@ class WebSocketConfig(Base): enabled: bool = False host: str = "127.0.0.1" port: int = 8765 + unix_socket_path: str = "" path: str = "/" token: str = "" token_issue_path: str = "" @@ -141,6 +191,19 @@ class WebSocketConfig(Base): ssl_certfile: str = "" ssl_keyfile: str = "" + @field_validator("unix_socket_path") + @classmethod + def unix_socket_path_format(cls, value: str) -> str: + value = value.strip() + if not value: + return "" + if "\x00" in value: + raise ValueError("unix_socket_path must not contain NUL bytes") + path = Path(value).expanduser() + if not path.is_absolute(): + raise ValueError("unix_socket_path must be an absolute path") + return str(path) + @field_validator("path") @classmethod def path_must_start_with_slash(cls, value: str) -> str: @@ -503,7 +566,10 @@ class WebSocketChannel(BaseChannel): session_manager: "SessionManager | None" = None, static_dist_path: Path | None = None, workspace_path: Path | None = None, + restrict_to_workspace: bool = False, runtime_model_name: Callable[[], str | None] | None = None, + runtime_surface: str = "browser", + runtime_capabilities_overrides: dict[str, Any] | None = None, ): if isinstance(config, dict): config = WebSocketConfig.model_validate(config) @@ -530,7 +596,20 @@ class WebSocketChannel(BaseChannel): if workspace_path is not None else get_workspace_path() ).resolve(strict=False) + self._default_restrict_to_workspace = restrict_to_workspace + self._webui_workspaces = WebUIWorkspaceController( + session_manager=self._session_manager, + default_workspace=self._workspace_path, + default_restrict_to_workspace=self._default_restrict_to_workspace, + ) self._runtime_model_name = runtime_model_name + self._runtime_surface = ( + "native" if runtime_surface in {"native", "desktop"} else "browser" + ) + self._runtime_capabilities = runtime_capabilities( + self._runtime_surface, + runtime_capabilities_overrides, + ) self._settings_restart_sections: set[str] = set() self._stream_text_buffers: dict[tuple[str, str], list[str]] = {} # Process-local secret used to HMAC-sign media URLs. The signed URL is @@ -695,6 +774,9 @@ class WebSocketChannel(BaseChannel): if got == "/api/commands": return self._handle_commands(request) + if got == "/api/workspaces": + return self._handle_workspaces(connection, request) + if got == "/api/webui/sidebar-state": return self._handle_webui_sidebar_state(request) @@ -707,15 +789,27 @@ class WebSocketChannel(BaseChannel): if got == "/api/settings/model-configurations/create": return self._handle_settings_model_configuration_create(request) + if got == "/api/settings/model-configurations/update": + return self._handle_settings_model_configuration_update(request) + if got == "/api/settings/provider/update": return self._handle_settings_provider_update(request) + if got == "/api/settings/provider/oauth-login": + return await self._handle_settings_provider_oauth(request, "login") + + if got == "/api/settings/provider/oauth-logout": + return await self._handle_settings_provider_oauth(request, "logout") + if got == "/api/settings/web-search/update": return self._handle_settings_web_search_update(request) if got == "/api/settings/image-generation/update": return self._handle_settings_image_generation_update(request) + if got == "/api/settings/network-safety/update": + return self._handle_settings_network_safety_update(request) + if got == "/api/settings/cli-apps": return self._handle_settings_cli_apps(request) @@ -773,6 +867,12 @@ class WebSocketChannel(BaseChannel): return connection.respond(403, "Forbidden") return self._authorize_websocket_handshake(connection, query) + # API clients should never receive the SPA shell for an unknown route. + # Returning HTML here makes the WebUI fail with "Unexpected token <" + # when a dev server is pointed at an older gateway. + if got.startswith("/api/"): + return _http_error(404, "API route not found") + # 5. Static SPA serving (only if a build directory was wired in). if self._static_dist_path is not None: response = self._serve_static(got) @@ -832,15 +932,32 @@ class WebSocketChannel(BaseChannel): # while the REST surface keeps validating the other until TTL expiry. self._issued_tokens[token] = expiry self._api_tokens[token] = expiry + ws_url = self._bootstrap_ws_url(request) return _http_json_response( { "token": token, "ws_path": self._expected_path(), + "ws_url": ws_url, "expires_in": self.config.token_ttl_s, "model_name": _resolve_bootstrap_model_name(self._runtime_model_name), + "runtime_surface": self._runtime_surface, + "runtime_capabilities": self._runtime_capabilities, } ) + def _bootstrap_ws_url(self, request: Any) -> str: + """Absolute WS URL clients should prefer over a dev-server proxy.""" + headers = getattr(request, "headers", {}) or {} + host = _safe_host_header(_case_insensitive_header(headers, "Host")) + if not host: + host = _host_for_url(self.config.host, self.config.port) + + proto = _case_insensitive_header(headers, "X-Forwarded-Proto") + proto = proto.split(",", 1)[0].strip().lower() + secure = proto in {"https", "wss"} or bool(self.config.ssl_certfile.strip()) + scheme = "wss" if secure else "ws" + return f"{scheme}://{host}{self._expected_path()}" + def _handle_sessions_list(self, request: WsRequest) -> Response: if not self._check_api_token(request): return _http_error(401, "Unauthorized") @@ -859,13 +976,29 @@ class WebSocketChannel(BaseChannel): started_at = websocket_turn_wall_started_at(chat_id) if started_at is not None: row["run_started_at"] = started_at + scope = self._webui_workspaces.scope_for_session_key(key) + row["workspace_scope"] = scope.payload() cleaned.append(row) return _http_json_response({"sessions": cleaned}) + def _handle_workspaces(self, connection: Any, request: WsRequest) -> Response: + if not self._check_api_token(request): + return _http_error(401, "Unauthorized") + return _http_json_response( + self._webui_workspaces.payload(controls_available=_is_localhost(connection)) + ) + def _handle_settings(self, request: WsRequest) -> Response: if not self._check_api_token(request): return _http_error(401, "Unauthorized") - return _http_json_response(self._with_settings_restart_state(settings_payload())) + return _http_json_response( + self._with_settings_restart_state( + settings_payload( + surface=self._runtime_surface, + runtime_capability_overrides=self._runtime_capabilities, + ) + ) + ) def _with_settings_restart_state( self, @@ -876,14 +1009,16 @@ class WebSocketChannel(BaseChannel): """Keep restart-required state alive for this gateway process.""" if section and payload.get("requires_restart"): self._settings_restart_sections.add(section) - if self._settings_restart_sections: - payload = dict(payload) + sections = sorted(self._settings_restart_sections) + payload = dict(payload) + if sections: payload["requires_restart"] = True - payload["restart_required_sections"] = sorted(self._settings_restart_sections) - else: - payload = dict(payload) - payload["restart_required_sections"] = [] - return payload + return decorate_settings_payload( + payload, + surface=self._runtime_surface, + runtime_capability_overrides=self._runtime_capabilities, + restart_required_sections=sections, + ) def _handle_commands(self, request: WsRequest) -> Response: if not self._check_api_token(request): @@ -939,6 +1074,16 @@ class WebSocketChannel(BaseChannel): return _http_error(e.status, e.message) return _http_json_response(self._with_settings_restart_state(payload)) + def _handle_settings_model_configuration_update(self, request: WsRequest) -> Response: + if not self._check_api_token(request): + return _http_error(401, "Unauthorized") + query = _parse_query(request.path) + try: + payload = update_model_configuration(query) + except WebUISettingsError as e: + return _http_error(e.status, e.message) + return _http_json_response(self._with_settings_restart_state(payload)) + def _handle_settings_provider_update(self, request: WsRequest) -> Response: if not self._check_api_token(request): return _http_error(401, "Unauthorized") @@ -949,6 +1094,19 @@ class WebSocketChannel(BaseChannel): return _http_error(e.status, e.message) return _http_json_response(self._with_settings_restart_state(payload, section="image")) + async def _handle_settings_provider_oauth(self, request: WsRequest, action: str) -> Response: + if not self._check_api_token(request): + return _http_error(401, "Unauthorized") + query = _parse_query(request.path) + try: + if action == "login": + payload = await asyncio.to_thread(login_oauth_provider, query) + else: + payload = await asyncio.to_thread(logout_oauth_provider, query) + except WebUISettingsError as e: + return _http_error(e.status, e.message) + return _http_json_response(self._with_settings_restart_state(payload)) + def _handle_settings_web_search_update(self, request: WsRequest) -> Response: if not self._check_api_token(request): return _http_error(401, "Unauthorized") @@ -957,7 +1115,7 @@ class WebSocketChannel(BaseChannel): payload = update_web_search_settings(query) except WebUISettingsError as e: return _http_error(e.status, e.message) - return _http_json_response(self._with_settings_restart_state(payload, section="web")) + return _http_json_response(self._with_settings_restart_state(payload, section="browser")) def _handle_settings_image_generation_update(self, request: WsRequest) -> Response: if not self._check_api_token(request): @@ -969,6 +1127,16 @@ class WebSocketChannel(BaseChannel): return _http_error(e.status, e.message) return _http_json_response(self._with_settings_restart_state(payload, section="image")) + def _handle_settings_network_safety_update(self, request: WsRequest) -> Response: + if not self._check_api_token(request): + return _http_error(401, "Unauthorized") + query = _parse_query(request.path) + try: + payload = update_network_safety_settings(query) + except WebUISettingsError as e: + return _http_error(e.status, e.message) + return _http_json_response(self._with_settings_restart_state(payload, section="runtime")) + def _handle_settings_cli_apps(self, request: WsRequest) -> Response: if not self._check_api_token(request): return _http_error(401, "Unauthorized") @@ -1058,13 +1226,19 @@ class WebSocketChannel(BaseChannel): return _http_error(400, "invalid session key") if not self._is_websocket_channel_session_key(decoded_key): return _http_error(404, "session not found") + scope = self._webui_workspaces.scope_for_session_key(decoded_key) data = build_webui_thread_response( decoded_key, augment_user_media=self._augment_transcript_user_media, - augment_assistant_text=self._rewrite_local_markdown_images, + augment_assistant_text=lambda text: rewrite_local_markdown_images( + text, + workspace_path=scope.project_path, + sign_path=self._sign_or_stage_media_path, + ), ) if data is None: return _http_error(404, "webui thread not found") + data["workspace_scope"] = scope.payload() return _http_json_response(data) def _try_append_webui_transcript(self, chat_id: str, wire: dict[str, Any]) -> None: @@ -1359,34 +1533,63 @@ class WebSocketChannel(BaseChannel): await self._connection_loop(connection) self.logger.info( - "WebSocket server listening on {}://{}:{}{}", - scheme, - self.config.host, - self.config.port, - self.config.path, + "WebSocket server listening on {}", + ( + f"unix:{self.config.unix_socket_path}{self.config.path}" + if self.config.unix_socket_path + else f"{scheme}://{self.config.host}:{self.config.port}{self.config.path}" + ), ) if self.config.token_issue_path: self.logger.info( - "WebSocket token issue route: {}://{}:{}{}", - scheme, - self.config.host, - self.config.port, - _normalize_config_path(self.config.token_issue_path), + "WebSocket token issue route: {}", + ( + f"unix:{self.config.unix_socket_path}{_normalize_config_path(self.config.token_issue_path)}" + if self.config.unix_socket_path + else ( + f"{scheme}://{self.config.host}:{self.config.port}" + f"{_normalize_config_path(self.config.token_issue_path)}" + ) + ), ) async def runner() -> None: - async with serve( - handler, - self.config.host, - self.config.port, - process_request=process_request, - max_size=self.config.max_message_bytes, - ping_interval=self.config.ping_interval_s, - ping_timeout=self.config.ping_timeout_s, - ssl=ssl_context, - ): + socket_path = self.config.unix_socket_path + if socket_path: + path_obj = Path(socket_path) + path_obj.parent.mkdir(parents=True, exist_ok=True) + with suppress(FileNotFoundError): + path_obj.unlink() + server = await unix_serve( + handler, + socket_path, + process_request=process_request, + max_size=self.config.max_message_bytes, + ping_interval=self.config.ping_interval_s, + ping_timeout=self.config.ping_timeout_s, + ) + with suppress(OSError): + path_obj.chmod(0o600) + else: + server = await serve( + handler, + self.config.host, + self.config.port, + process_request=process_request, + max_size=self.config.max_message_bytes, + ping_interval=self.config.ping_interval_s, + ping_timeout=self.config.ping_timeout_s, + ssl=ssl_context, + ) + try: assert self._stop_event is not None await self._stop_event.wait() + finally: + server.close() + await server.wait_closed() + if socket_path: + with suppress(FileNotFoundError): + Path(socket_path).unlink() self._server_task = asyncio.create_task(runner()) await self._server_task @@ -1530,8 +1733,25 @@ class WebSocketChannel(BaseChannel): t = envelope.get("type") if t == "new_chat": new_id = str(uuid.uuid4()) + scope = await self._workspace_scope_or_error( + connection, + lambda: self._webui_workspaces.scope_for_new_chat( + envelope, + controls_available=_is_localhost(connection), + ), + ) + if scope is None: + return + self._webui_workspaces.persist_scope(new_id, scope) self._attach(connection, new_id) await self._send_event(connection, "attached", chat_id=new_id) + await self._send_event( + connection, + "session_updated", + chat_id=new_id, + scope="metadata", + workspace_scope=scope.payload(), + ) await self._hydrate_after_subscribe(new_id) return if t == "attach": @@ -1543,6 +1763,32 @@ class WebSocketChannel(BaseChannel): await self._send_event(connection, "attached", chat_id=cid) await self._hydrate_after_subscribe(cid) return + if t == "set_workspace_scope": + cid = envelope.get("chat_id") + if not _is_valid_chat_id(cid): + await self._send_event(connection, "error", detail="invalid chat_id") + return + scope = await self._workspace_scope_or_error( + connection, + lambda: self._webui_workspaces.scope_for_set_request( + envelope, + chat_id=cid, + chat_running=websocket_turn_wall_started_at(cid) is not None, + controls_available=_is_localhost(connection), + ), + chat_id=cid, + ) + if scope is None: + return + self._webui_workspaces.persist_scope(cid, scope) + await self._send_event( + connection, + "session_updated", + chat_id=cid, + scope="metadata", + workspace_scope=scope.payload(), + ) + return if t == "message": cid = envelope.get("chat_id") content = envelope.get("content") @@ -1574,6 +1820,18 @@ class WebSocketChannel(BaseChannel): if not content.strip() and not media_paths: await self._send_event(connection, "error", detail="missing content") return + scope = await self._workspace_scope_or_error( + connection, + lambda: self._webui_workspaces.scope_for_message( + envelope, + chat_id=cid, + chat_running=websocket_turn_wall_started_at(cid) is not None, + controls_available=_is_localhost(connection), + ), + chat_id=cid, + ) + if scope is None: + return # Auto-attach on first use so clients can one-shot without a separate attach. self._attach(connection, cid) @@ -1587,6 +1845,8 @@ class WebSocketChannel(BaseChannel): mcp_presets = normalize_mcp_preset_mentions(envelope.get("mcp_presets")) if mcp_presets: metadata["mcp_presets"] = mcp_presets + metadata[WORKSPACE_SCOPE_METADATA_KEY] = scope.metadata() + self._webui_workspaces.persist_scope(cid, scope) image_generation = envelope.get("image_generation") if isinstance(image_generation, dict) and image_generation.get("enabled") is True: aspect_ratio = image_generation.get("aspect_ratio") @@ -1605,6 +1865,25 @@ class WebSocketChannel(BaseChannel): return await self._send_event(connection, "error", detail=f"unknown type: {t!r}") + async def _workspace_scope_or_error( + self, + connection: Any, + resolver: Callable[[], Any], + *, + chat_id: str | None = None, + ) -> Any | None: + try: + return resolver() + except WorkspaceScopeError as exc: + await self._send_event( + connection, + "error", + detail="workspace_scope_rejected", + reason=exc.message, + **({"chat_id": chat_id} if chat_id else {}), + ) + return None + async def stop(self) -> None: if not self._running: return diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 1c0cbbdfc..acebdabfd 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -720,11 +720,144 @@ def gateway( _run_gateway(cfg, port=port) +def _load_or_create_desktop_config(config: str | None, workspace: str | None) -> Config: + """Load the desktop-owned config, creating it on first launch.""" + from nanobot.config.loader import ( + get_config_path, + load_config, + resolve_config_env_vars, + save_config, + set_config_path, + ) + from nanobot.config.schema import Config as NanobotConfig + + config_path = Path(config).expanduser().resolve() if config else get_config_path() + set_config_path(config_path) + created = False + if config_path.exists(): + try: + loaded = resolve_config_env_vars(load_config(config_path)) + except ValueError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) + else: + loaded = NanobotConfig() + created = True + + if workspace: + workspace_path = Path(workspace).expanduser() + loaded.agents.defaults.workspace = str(workspace_path) + created = True + + if created: + save_config(loaded, config_path) + return loaded + + +def _configure_desktop_gateway( + config: Config, + *, + webui_port: int, + webui_socket: str | None, + token_issue_secret: str, +) -> None: + """Force a local WebSocket-only gateway for the desktop app process.""" + config.gateway.host = "127.0.0.1" + config.gateway.port = webui_port + config.gateway.heartbeat.enabled = False + + extras = dict(getattr(config.channels, "__pydantic_extra__", None) or {}) + for name, section in list(extras.items()): + if name == "websocket": + continue + if isinstance(section, dict): + extras[name] = {**section, "enabled": False} + else: + with suppress(Exception): + setattr(section, "enabled", False) + extras[name] = section + + websocket_cfg = extras.get("websocket") + if not isinstance(websocket_cfg, dict): + websocket_cfg = {} + websocket_cfg.update( + { + "enabled": True, + "host": "127.0.0.1", + "port": webui_port, + "unix_socket_path": webui_socket or "", + "path": "/", + "token_issue_secret": token_issue_secret, + "websocket_requires_token": True, + "allow_from": ["*"], + "streaming": True, + } + ) + extras["websocket"] = websocket_cfg + config.channels.__pydantic_extra__ = extras + + +@app.command("desktop-gateway", hidden=True) +def desktop_gateway( + webui_port: int = typer.Option(0, "--webui-port", min=0, max=65535), + webui_socket: str | None = typer.Option(None, "--webui-socket", help="Unix socket path for desktop IPC"), + token_issue_secret: str = typer.Option(..., "--token-issue-secret"), + workspace: str | None = typer.Option(None, "--workspace", "-w", help="Desktop workspace directory"), + config: str | None = typer.Option(None, "--config", "-c", help="Desktop config file"), + verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"), +): + """Start the private local gateway used by nanobot Desktop.""" + if not token_issue_secret.strip(): + console.print("[red]Error: --token-issue-secret is required[/red]") + raise typer.Exit(1) + if webui_port <= 0 and not (webui_socket or "").strip(): + console.print("[red]Error: --webui-port or --webui-socket is required[/red]") + raise typer.Exit(1) + if verbose: + logger.remove(_log_handler_id) + logger.add( + sys.stderr, + format=( + "{time:YYYY-MM-DD HH:mm:ss} | " + "{level: <5} | " + "{extra[channel]} | " + "{message}" + ), + level="DEBUG", + colorize=None, + filter=lambda record: record["extra"].setdefault("channel", "-") or True, + ) + cfg = _load_or_create_desktop_config(config, workspace) + _configure_desktop_gateway( + cfg, + webui_port=webui_port, + webui_socket=webui_socket, + token_issue_secret=token_issue_secret, + ) + _run_gateway( + cfg, + port=webui_port, + webui_static_dist=False, + webui_runtime_surface="native", + webui_runtime_capabilities={ + "can_restart_engine": True, + "can_pick_folder": True, + "can_open_logs": True, + "can_export_diagnostics": True, + }, + health_server_enabled=False, + ) + + def _run_gateway( config: Config, *, port: int | None = None, open_browser_url: str | None = None, + webui_static_dist: bool = True, + webui_runtime_surface: str = "browser", + webui_runtime_capabilities: dict[str, Any] | None = None, + health_server_enabled: bool = True, ) -> None: """Shared gateway runtime; ``open_browser_url`` opens a tab once channels are up.""" from nanobot.agent.tools.cron import CronTool @@ -957,6 +1090,9 @@ def _run_gateway( bus, session_manager=session_manager, webui_runtime_model_name=_webui_runtime_model_name, + webui_static_dist=webui_static_dist, + webui_runtime_surface=webui_runtime_surface, + webui_runtime_capabilities=webui_runtime_capabilities, ) def _pick_heartbeat_target() -> tuple[str, str]: @@ -1088,8 +1224,9 @@ def _run_gateway( tasks = [ agent.run(), channels.start_all(), - _health_server(config.gateway.host, port), ] + if health_server_enabled: + tasks.append(_health_server(config.gateway.host, port)) if open_browser_url: tasks.append(_open_browser_when_ready()) await asyncio.gather(*tasks) diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index db6650d91..edf40bfed 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -295,7 +295,16 @@ class ToolsConfig(Base): image_generation: ImageGenerationToolConfig = Field( default_factory=lambda: _lazy_default("nanobot.agent.tools.image_generation", "ImageGenerationToolConfig"), ) - restrict_to_workspace: bool = False # restrict all tool access to workspace directory + restrict_to_workspace: bool = False # policy intent: keep tool access inside workspace when possible + webui_allow_local_service_access: bool = Field( + default=True, + validation_alias=AliasChoices( + "webuiAllowLocalServiceAccess", + "webui_allow_local_service_access", + "allowLocalPreviewAccess", + "allow_local_preview_access", + ), + ) # allow WebUI Full Access shell checks against localhost services; legacy allowLocalPreviewAccess still reads mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict) ssrf_whitelist: list[str] = Field(default_factory=list) # CIDR ranges to exempt from SSRF blocking (e.g. ["100.64.0.0/10"] for Tailscale) @@ -314,6 +323,11 @@ class Config(BaseSettings): validation_alias=AliasChoices("modelPresets", "model_presets"), ) + def __init__(self, **values: Any) -> None: + if not type(self).__pydantic_complete__: + _resolve_tool_config_refs() + super().__init__(**values) + @model_validator(mode="after") def _validate_model_preset(self) -> "Config": if "default" in self.model_presets: diff --git a/nanobot/providers/openai_codex_provider.py b/nanobot/providers/openai_codex_provider.py index c0afdf572..fc92e8ae8 100644 --- a/nanobot/providers/openai_codex_provider.py +++ b/nanobot/providers/openai_codex_provider.py @@ -15,7 +15,7 @@ from oauth_cli_kit import get_token as get_codex_token from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.openai_responses import ( - consume_sse, + consume_sse_with_reasoning, convert_messages, convert_tools, ) @@ -41,6 +41,7 @@ class OpenAICodexProvider(LLMProvider): reasoning_effort: str | None, tool_choice: str | dict[str, Any] | None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, + on_thinking_delta: Callable[[str], Awaitable[None]] | None = None, on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, ) -> LLMResponse: """Shared request logic for both chat() and chat_stream().""" @@ -62,28 +63,36 @@ class OpenAICodexProvider(LLMProvider): "tool_choice": tool_choice or "auto", "parallel_tool_calls": True, } - if reasoning_effort and reasoning_effort.lower() != "none": - body["reasoning"] = {"effort": reasoning_effort} + reasoning_options = _build_reasoning_options(reasoning_effort) + if reasoning_options: + body["reasoning"] = reasoning_options if tools: body["tools"] = convert_tools(tools) try: try: - content, tool_calls, finish_reason = await _request_codex( + content, tool_calls, finish_reason, reasoning_content = await _request_codex( DEFAULT_CODEX_URL, headers, body, verify=True, on_content_delta=on_content_delta, + on_thinking_delta=on_thinking_delta, on_tool_call_delta=on_tool_call_delta, ) except Exception as e: if "CERTIFICATE_VERIFY_FAILED" not in str(e): raise logger.warning("SSL verification failed for Codex API; retrying with verify=False") - content, tool_calls, finish_reason = await _request_codex( + content, tool_calls, finish_reason, reasoning_content = await _request_codex( DEFAULT_CODEX_URL, headers, body, verify=False, on_content_delta=on_content_delta, + on_thinking_delta=on_thinking_delta, on_tool_call_delta=on_tool_call_delta, ) - return LLMResponse(content=content, tool_calls=tool_calls, finish_reason=finish_reason) + return LLMResponse( + content=content, + tool_calls=tool_calls, + finish_reason=finish_reason, + reasoning_content=reasoning_content, + ) except Exception as e: response = _codex_error_response(e) exc_type = "CodexHTTPError" if isinstance(e, _CodexHTTPError) else type(e).__name__ @@ -118,7 +127,6 @@ class OpenAICodexProvider(LLMProvider): on_thinking_delta: Callable[[str], Awaitable[None]] | None = None, on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, ) -> LLMResponse: - _ = on_thinking_delta return await self._call_codex( messages, tools, @@ -126,6 +134,7 @@ class OpenAICodexProvider(LLMProvider): reasoning_effort, tool_choice, on_content_delta, + on_thinking_delta, on_tool_call_delta, ) @@ -139,6 +148,16 @@ def _strip_model_prefix(model: str) -> str: return model +def _build_reasoning_options(reasoning_effort: str | None) -> dict[str, str] | None: + """Opt in to visible summaries without changing provider-default effort.""" + if reasoning_effort and reasoning_effort.lower() == "none": + return {"effort": "none"} + options = {"summary": "auto"} + if reasoning_effort: + options["effort"] = reasoning_effort + return options + + def _build_headers(account_id: str, token: str) -> dict[str, str]: return { "Authorization": f"Bearer {token}", @@ -176,8 +195,9 @@ async def _request_codex( body: dict[str, Any], verify: bool, on_content_delta: Callable[[str], Awaitable[None]] | None = None, + on_thinking_delta: Callable[[str], Awaitable[None]] | None = None, on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, -) -> tuple[str, list[ToolCallRequest], str]: +) -> tuple[str, list[ToolCallRequest], str, str | None]: idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90")) async with httpx.AsyncClient(timeout=idle_timeout_s, verify=verify) as client: async with client.stream("POST", url, headers=headers, json=body) as response: @@ -194,7 +214,12 @@ async def _request_codex( error_code=error_code, should_retry=_should_retry_status(response.status_code, error_type, error_code, raw), ) - return await consume_sse(response, on_content_delta, on_tool_call_delta) + return await consume_sse_with_reasoning( + response, + on_content_delta=on_content_delta, + on_tool_call_delta=on_tool_call_delta, + on_reasoning_delta=on_thinking_delta, + ) def _prompt_cache_key(messages: list[dict[str, Any]]) -> str: diff --git a/nanobot/providers/openai_responses/__init__.py b/nanobot/providers/openai_responses/__init__.py index b40e896ed..25f19afb6 100644 --- a/nanobot/providers/openai_responses/__init__.py +++ b/nanobot/providers/openai_responses/__init__.py @@ -10,6 +10,7 @@ from nanobot.providers.openai_responses.parsing import ( FINISH_REASON_MAP, consume_sdk_stream, consume_sse, + consume_sse_with_reasoning, iter_sse, map_finish_reason, parse_response_output, @@ -22,6 +23,7 @@ __all__ = [ "split_tool_call_id", "iter_sse", "consume_sse", + "consume_sse_with_reasoning", "consume_sdk_stream", "map_finish_reason", "parse_response_output", diff --git a/nanobot/providers/openai_responses/parsing.py b/nanobot/providers/openai_responses/parsing.py index 707652d74..846165562 100644 --- a/nanobot/providers/openai_responses/parsing.py +++ b/nanobot/providers/openai_responses/parsing.py @@ -65,10 +65,28 @@ async def consume_sse( on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, ) -> tuple[str, list[ToolCallRequest], str]: """Consume a Responses API SSE stream into ``(content, tool_calls, finish_reason)``.""" + content, tool_calls, finish_reason, _ = await consume_sse_with_reasoning( + response, + on_content_delta=on_content_delta, + on_tool_call_delta=on_tool_call_delta, + ) + return content, tool_calls, finish_reason + + +async def consume_sse_with_reasoning( + response: httpx.Response, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, + on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, + on_reasoning_delta: Callable[[str], Awaitable[None]] | None = None, +) -> tuple[str, list[ToolCallRequest], str, str | None]: + """Consume a Responses API SSE stream, including visible reasoning summaries.""" content = "" tool_calls: list[ToolCallRequest] = [] tool_call_buffers: dict[str, dict[str, Any]] = {} + tool_call_args_emitted: set[str] = set() finish_reason = "stop" + reasoning_content: str | None = None + streamed_reasoning = False async for event in iter_sse(response): event_type = event.get("type") @@ -94,6 +112,26 @@ async def consume_sse( content += delta_text if on_content_delta and delta_text: await on_content_delta(delta_text) + elif event_type == "response.reasoning_summary_text.delta": + delta_text = event.get("delta") or "" + if delta_text: + reasoning_content = (reasoning_content or "") + delta_text + streamed_reasoning = True + if on_reasoning_delta: + await on_reasoning_delta(delta_text) + elif event_type == "response.reasoning_summary_text.done": + text = event.get("text") or "" + if text and not streamed_reasoning and not reasoning_content: + reasoning_content = text + if on_reasoning_delta: + await on_reasoning_delta(text) + elif event_type == "response.reasoning_summary_part.done": + part = event.get("part") or {} + text = part.get("text") if part.get("type") == "summary_text" else None + if text and not streamed_reasoning and not reasoning_content: + reasoning_content = text + if on_reasoning_delta: + await on_reasoning_delta(text) elif event_type == "response.function_call_arguments.delta": call_id = event.get("call_id") if call_id and call_id in tool_call_buffers: @@ -108,7 +146,15 @@ async def consume_sse( elif event_type == "response.function_call_arguments.done": call_id = event.get("call_id") if call_id and call_id in tool_call_buffers: - tool_call_buffers[call_id]["arguments"] = event.get("arguments") or "" + arguments = event.get("arguments") or "" + tool_call_buffers[call_id]["arguments"] = arguments + if on_tool_call_delta: + tool_call_args_emitted.add(str(call_id)) + await on_tool_call_delta({ + "call_id": str(call_id), + "name": str(tool_call_buffers[call_id].get("name") or ""), + "arguments": str(arguments), + }) elif event_type == "response.output_item.done": item = event.get("item") or {} if item.get("type") == "function_call": @@ -117,6 +163,13 @@ async def consume_sse( continue buf = tool_call_buffers.get(call_id) or {} args_raw = buf.get("arguments") or item.get("arguments") or "{}" + if on_tool_call_delta and str(call_id) not in tool_call_args_emitted: + tool_call_args_emitted.add(str(call_id)) + await on_tool_call_delta({ + "call_id": str(call_id), + "name": str(buf.get("name") or item.get("name") or ""), + "arguments": str(args_raw), + }) try: args = json.loads(args_raw) except Exception: @@ -135,14 +188,44 @@ async def consume_sse( arguments=args, ) ) + elif item.get("type") == "reasoning" and not reasoning_content: + summary = _extract_reasoning_summary_from_output([item]) + if summary: + reasoning_content = summary + if on_reasoning_delta: + await on_reasoning_delta(summary) elif event_type == "response.completed": - status = (event.get("response") or {}).get("status") + response_obj = event.get("response") or {} + status = response_obj.get("status") finish_reason = map_finish_reason(status) + if not reasoning_content: + summary = _extract_reasoning_summary_from_output(response_obj.get("output") or []) + if summary: + reasoning_content = summary + if on_reasoning_delta: + await on_reasoning_delta(summary) elif event_type in {"error", "response.failed"}: detail = event.get("error") or event.get("message") or event raise RuntimeError(f"Response failed: {str(detail)[:500]}") - return content, tool_calls, finish_reason + return content, tool_calls, finish_reason, reasoning_content + + +def _extract_reasoning_summary_from_output(output: Any) -> str | None: + parts: list[str] = [] + for item in output or []: + if not isinstance(item, dict): + dump = getattr(item, "model_dump", None) + item = dump() if callable(dump) else vars(item) + if item.get("type") != "reasoning": + continue + for summary in item.get("summary") or []: + if not isinstance(summary, dict): + dump = getattr(summary, "model_dump", None) + summary = dump() if callable(dump) else vars(summary) + if summary.get("type") == "summary_text" and summary.get("text"): + parts.append(summary["text"]) + return "".join(parts) or None def parse_response_output(response: Any) -> LLMResponse: @@ -230,6 +313,7 @@ async def consume_sdk_stream( content = "" tool_calls: list[ToolCallRequest] = [] tool_call_buffers: dict[str, dict[str, Any]] = {} + tool_call_args_emitted: set[str] = set() finish_reason = "stop" usage: dict[str, int] = {} reasoning_content: str | None = None @@ -272,7 +356,15 @@ async def consume_sdk_stream( elif event_type == "response.function_call_arguments.done": call_id = getattr(event, "call_id", None) if call_id and call_id in tool_call_buffers: - tool_call_buffers[call_id]["arguments"] = getattr(event, "arguments", "") or "" + arguments = getattr(event, "arguments", "") or "" + tool_call_buffers[call_id]["arguments"] = arguments + if on_tool_call_delta: + tool_call_args_emitted.add(str(call_id)) + await on_tool_call_delta({ + "call_id": str(call_id), + "name": str(tool_call_buffers[call_id].get("name") or ""), + "arguments": str(arguments), + }) elif event_type == "response.output_item.done": item = getattr(event, "item", None) if item and getattr(item, "type", None) == "function_call": @@ -281,6 +373,13 @@ async def consume_sdk_stream( continue buf = tool_call_buffers.get(call_id) or {} args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}" + if on_tool_call_delta and str(call_id) not in tool_call_args_emitted: + tool_call_args_emitted.add(str(call_id)) + await on_tool_call_delta({ + "call_id": str(call_id), + "name": str(buf.get("name") or getattr(item, "name", None) or ""), + "arguments": str(args_raw), + }) try: args = json.loads(args_raw) except Exception: diff --git a/nanobot/security/network.py b/nanobot/security/network.py index 54676b5d9..dfd3e9e47 100644 --- a/nanobot/security/network.py +++ b/nanobot/security/network.py @@ -42,9 +42,14 @@ def _is_private(addr: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool: return any(addr in net for net in _BLOCKED_NETWORKS) -def validate_url_target(url: str) -> tuple[bool, str]: +def validate_url_target(url: str, *, allow_loopback: bool = False) -> tuple[bool, str]: """Validate a URL is safe to fetch: scheme, hostname, and resolved IPs. + ``allow_loopback`` is intentionally narrow: it only permits literal + loopback hosts (localhost, 127.0.0.0/8, ::1) when every resolved address is + loopback. It does not allow RFC1918, link-local, metadata, or public DNS + names that happen to resolve to loopback. + Returns (ok, error_message). When ok is True, error_message is empty. """ try: @@ -66,11 +71,16 @@ def validate_url_target(url: str) -> tuple[bool, str]: except socket.gaierror: return False, f"Cannot resolve hostname: {hostname}" + addrs: list[ipaddress.IPv4Address | ipaddress.IPv6Address] = [] for info in infos: try: addr = ipaddress.ip_address(info[4][0]) except ValueError: continue + addrs.append(addr) + if allow_loopback and _is_allowed_loopback_target(hostname, addrs): + return True, "" + for addr in addrs: if _is_private(addr): return False, f"Blocked: {hostname} resolves to private/internal address {addr}" @@ -109,11 +119,25 @@ def validate_resolved_url(url: str) -> tuple[bool, str]: return True, "" -def contains_internal_url(command: str) -> bool: +def contains_internal_url(command: str, *, allow_loopback: bool = False) -> bool: """Return True if the command string contains a URL targeting an internal/private address.""" for m in _URL_RE.finditer(command): url = m.group(0) - ok, _ = validate_url_target(url) + ok, _ = validate_url_target(url, allow_loopback=allow_loopback) if not ok: return True return False + + +def _is_allowed_loopback_target( + hostname: str, + addrs: list[ipaddress.IPv4Address | ipaddress.IPv6Address], +) -> bool: + if not addrs or not all(addr.is_loopback for addr in addrs): + return False + normalized = hostname.rstrip(".").lower() + if normalized == "localhost": + return True + with suppress(ValueError): + return ipaddress.ip_address(hostname).is_loopback + return False diff --git a/nanobot/security/workspace_access.py b/nanobot/security/workspace_access.py new file mode 100644 index 000000000..59c54559d --- /dev/null +++ b/nanobot/security/workspace_access.py @@ -0,0 +1,430 @@ +"""Workspace access scope and sandbox capability helpers.""" + +from __future__ import annotations + +import os +from contextvars import ContextVar, Token +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal + +WorkspaceAccessMode = Literal["restricted", "full"] +WORKSPACE_SCOPE_METADATA_KEY = "workspace_scope" +_ACCESS_MODES = {"restricted", "full"} + +_TRUE_VALUES = {"1", "true", "yes", "on", "enabled"} +_FALSE_VALUES = {"0", "false", "no", "off", "disabled", ""} +_PROVIDER_LABELS = { + "none": "None", + "unknown": "Unknown system sandbox", + "macos_app_sandbox": "macOS App Sandbox", + "bwrap": "Bubblewrap", +} + +_CURRENT_WORKSPACE_SCOPE: ContextVar["WorkspaceScope | None"] = ContextVar( + "nanobot_workspace_scope", + default=None, +) + + +class WorkspaceScopeError(ValueError): + """Raised when a requested WebUI workspace scope is invalid.""" + + status = 400 + + def __init__(self, message: str, *, status: int = 400) -> None: + super().__init__(message) + self.message = message + self.status = status + + +@dataclass(frozen=True) +class WorkspaceSandboxStatus: + """Resolved workspace sandbox state for runtime display and tooling.""" + + restrict_to_workspace: bool + workspace_root: str + level: str + enforced: bool + provider: str + provider_label: str + summary: str + + def as_dict(self) -> dict[str, object]: + return { + "restrict_to_workspace": self.restrict_to_workspace, + "workspace_root": self.workspace_root, + "level": self.level, + "enforced": self.enforced, + "provider": self.provider, + "provider_label": self.provider_label, + "summary": self.summary, + } + + +@dataclass(frozen=True) +class WorkspaceScope: + """Effective project root and access mode for one agent turn.""" + + project_path: Path + access_mode: WorkspaceAccessMode + restrict_to_workspace: bool + sandbox_status: WorkspaceSandboxStatus + source_channel: str | None = None + + @property + def project_name(self) -> str: + return self.project_path.name or str(self.project_path) + + def metadata(self) -> dict[str, str]: + return { + "project_path": str(self.project_path), + "access_mode": self.access_mode, + } + + def payload(self) -> dict[str, Any]: + return { + **self.metadata(), + "project_name": self.project_name, + "restrict_to_workspace": self.restrict_to_workspace, + "sandbox_status": self.sandbox_status.as_dict(), + } + + +@dataclass(frozen=True) +class ToolWorkspace: + """Workspace policy resolved for a tool call.""" + + project_path: Path | None + restrict_to_workspace: bool + scope: WorkspaceScope | None = None + + @property + def allowed_root(self) -> Path | None: + if self.restrict_to_workspace and self.project_path is not None: + return self.project_path + return None + + +@dataclass(frozen=True) +class WorkspaceScopeResolver: + """Resolve the effective workspace scope at an agent turn boundary.""" + + default_workspace: str | Path + default_restrict_to_workspace: bool + scoped_channel: str = "websocket" + + @property + def sandbox_status(self) -> WorkspaceSandboxStatus: + return self.default().sandbox_status + + def default(self) -> WorkspaceScope: + return default_workspace_scope( + self.default_workspace, + self.default_restrict_to_workspace, + ) + + def for_message( + self, + msg: Any, + session_metadata: Any, + ) -> WorkspaceScope: + return self.for_turn( + channel=getattr(msg, "channel", None), + message_metadata=getattr(msg, "metadata", None), + session_metadata=session_metadata, + ) + + def for_turn( + self, + *, + channel: str | None, + message_metadata: Any, + session_metadata: Any, + ) -> WorkspaceScope: + if channel != self.scoped_channel: + return self.default() + return resolve_effective_workspace_scope( + message_metadata=message_metadata, + session_metadata=session_metadata, + default_workspace=self.default_workspace, + default_restrict_to_workspace=self.default_restrict_to_workspace, + source_channel=channel, + ) + + def persist_message_scope(self, session: Any, msg: Any) -> None: + if getattr(msg, "channel", None) != self.scoped_channel: + return + metadata = getattr(msg, "metadata", None) + if not isinstance(metadata, dict): + return + raw = metadata.get(WORKSPACE_SCOPE_METADATA_KEY) + if isinstance(raw, dict): + session.metadata[WORKSPACE_SCOPE_METADATA_KEY] = dict(raw) + + +def workspace_sandbox_status( + *, + restrict_to_workspace: bool, + workspace: str | Path, + environ: dict[str, str] | None = None, +) -> WorkspaceSandboxStatus: + """Return how workspace restriction is enforced in the current host.""" + + workspace_root = str(Path(workspace).expanduser().resolve(strict=False)) + provider = _env_system_provider(environ) + if not restrict_to_workspace: + return WorkspaceSandboxStatus( + restrict_to_workspace=False, + workspace_root=workspace_root, + level="off", + enforced=False, + provider="none", + provider_label=_provider_label("none"), + summary="Workspace restriction is disabled.", + ) + + if provider: + label = _provider_label(provider) + return WorkspaceSandboxStatus( + restrict_to_workspace=True, + workspace_root=workspace_root, + level="system", + enforced=True, + provider=provider, + provider_label=label, + summary=f"Workspace restriction is system-enforced by {label}.", + ) + + return WorkspaceSandboxStatus( + restrict_to_workspace=True, + workspace_root=workspace_root, + level="application", + enforced=False, + provider="none", + provider_label=_provider_label("none"), + summary="Workspace restriction uses nanobot application-level guards.", + ) + + +def default_access_mode(restrict_to_workspace: bool) -> WorkspaceAccessMode: + return "restricted" if restrict_to_workspace else "full" + + +def build_workspace_scope( + project_path: str | Path, + access_mode: str, + *, + source_channel: str | None = None, +) -> WorkspaceScope: + mode = _normalize_access_mode(access_mode) + root = Path(project_path).expanduser().resolve(strict=False) + restrict = mode == "restricted" + return WorkspaceScope( + project_path=root, + access_mode=mode, + restrict_to_workspace=restrict, + sandbox_status=workspace_sandbox_status( + restrict_to_workspace=restrict, + workspace=root, + ), + source_channel=source_channel, + ) + + +def default_workspace_scope( + workspace: str | Path, + restrict_to_workspace: bool, + *, + source_channel: str | None = None, +) -> WorkspaceScope: + return build_workspace_scope( + workspace, + default_access_mode(restrict_to_workspace), + source_channel=source_channel, + ) + + +def validate_workspace_scope_payload( + raw: Any, + *, + default_workspace: str | Path, + default_restrict_to_workspace: bool, + source_channel: str | None = None, +) -> WorkspaceScope: + """Validate a client-requested workspace scope.""" + if raw is None: + return default_workspace_scope( + default_workspace, + default_restrict_to_workspace, + source_channel=source_channel, + ) + if not isinstance(raw, dict): + raise WorkspaceScopeError("workspace_scope must be an object") + + raw_path = raw.get("project_path") or raw.get("path") + if raw_path is None or raw_path == "": + raw_path = str(Path(default_workspace).expanduser().resolve(strict=False)) + if not isinstance(raw_path, str): + raise WorkspaceScopeError("project_path must be a string") + if "\0" in raw_path: + raise WorkspaceScopeError("project_path contains invalid characters") + + project = Path(raw_path).expanduser() + if not project.is_absolute(): + raise WorkspaceScopeError("project_path must be absolute") + project = project.resolve(strict=False) + if not project.is_dir(): + raise WorkspaceScopeError("project_path must be an existing directory") + + raw_mode = raw.get("access_mode") + if raw_mode is None: + raw_mode = default_access_mode(default_restrict_to_workspace) + if not isinstance(raw_mode, str): + raise WorkspaceScopeError("access_mode must be a string") + return build_workspace_scope(project, raw_mode, source_channel=source_channel) + + +def workspace_scope_from_metadata( + metadata: Any, + *, + default_workspace: str | Path, + default_restrict_to_workspace: bool, + source_channel: str | None = None, +) -> WorkspaceScope: + """Resolve persisted metadata, falling back safely for old or stale sessions.""" + if not isinstance(metadata, dict): + return default_workspace_scope( + default_workspace, + default_restrict_to_workspace, + source_channel=source_channel, + ) + try: + return validate_workspace_scope_payload( + metadata.get(WORKSPACE_SCOPE_METADATA_KEY), + default_workspace=default_workspace, + default_restrict_to_workspace=default_restrict_to_workspace, + source_channel=source_channel, + ) + except WorkspaceScopeError: + return default_workspace_scope( + default_workspace, + default_restrict_to_workspace, + source_channel=source_channel, + ) + + +def resolve_effective_workspace_scope( + *, + message_metadata: Any, + session_metadata: Any, + default_workspace: str | Path, + default_restrict_to_workspace: bool, + source_channel: str | None = None, +) -> WorkspaceScope: + if isinstance(message_metadata, dict) and WORKSPACE_SCOPE_METADATA_KEY in message_metadata: + return workspace_scope_from_metadata( + message_metadata, + default_workspace=default_workspace, + default_restrict_to_workspace=default_restrict_to_workspace, + source_channel=source_channel, + ) + return workspace_scope_from_metadata( + session_metadata, + default_workspace=default_workspace, + default_restrict_to_workspace=default_restrict_to_workspace, + source_channel=source_channel, + ) + + +def bind_workspace_scope(scope: WorkspaceScope) -> Token[WorkspaceScope | None]: + return _CURRENT_WORKSPACE_SCOPE.set(scope) + + +def reset_workspace_scope(token: Token[WorkspaceScope | None]) -> None: + _CURRENT_WORKSPACE_SCOPE.reset(token) + + +def current_workspace_scope() -> WorkspaceScope | None: + return _CURRENT_WORKSPACE_SCOPE.get() + + +def current_tool_workspace( + default_workspace: str | Path | None, + *, + restrict_to_workspace: bool = False, + sandbox_restricts_workspace: bool = False, +) -> ToolWorkspace: + """Return the workspace/access policy for the current tool call.""" + + scope = current_workspace_scope() + project_path = ( + scope.project_path + if scope is not None + else Path(default_workspace).expanduser() if default_workspace is not None else None + ) + restrict = ( + scope.restrict_to_workspace + if scope is not None + else bool(restrict_to_workspace) + ) or sandbox_restricts_workspace + return ToolWorkspace( + project_path=project_path, + restrict_to_workspace=restrict, + scope=scope, + ) + + +def current_scope_allows_loopback(*, enabled: bool) -> bool: + """Return True when the current WebUI Full Access turn may touch loopback URLs.""" + + scope = current_workspace_scope() + return bool( + enabled + and scope is not None + and scope.source_channel == "websocket" + and scope.access_mode == "full" + and not scope.restrict_to_workspace + ) + + +def _env_system_provider(environ: dict[str, str] | None = None) -> str | None: + env = environ if environ is not None else os.environ + explicit_provider = env.get("NANOBOT_WORKSPACE_SANDBOX_PROVIDER") + enforced = env.get("NANOBOT_WORKSPACE_SANDBOX_ENFORCED") + compatibility = env.get("NANOBOT_SANDBOX_ENFORCED") + + marker = enforced if enforced is not None else compatibility + if marker is None: + return None + + normalized_marker = marker.strip().lower() + if normalized_marker in _FALSE_VALUES: + return None + if normalized_marker in _TRUE_VALUES: + return _normalize_provider(explicit_provider) + return _normalize_provider(marker) + + +def _normalize_provider(value: str | None) -> str: + if not value: + return "unknown" + normalized = value.strip().lower().replace("-", "_").replace(" ", "_") + return normalized or "unknown" + + +def _provider_label(provider: str) -> str: + if provider in _PROVIDER_LABELS: + return _PROVIDER_LABELS[provider] + return provider.replace("_", " ").title() + + +def _normalize_access_mode(value: str) -> WorkspaceAccessMode: + mode = value.strip().lower().replace("_", "-") + if mode == "restrict": + mode = "restricted" + if mode == "full-access": + mode = "full" + if mode not in _ACCESS_MODES: + raise WorkspaceScopeError("access_mode must be restricted or full") + return mode # type: ignore[return-value] diff --git a/nanobot/security/workspace_policy.py b/nanobot/security/workspace_policy.py new file mode 100644 index 000000000..31ebde807 --- /dev/null +++ b/nanobot/security/workspace_policy.py @@ -0,0 +1,85 @@ +"""Workspace path boundary helpers. + +These helpers are application-level guards. They make path decisions +consistent across tools, but they are not a replacement for an OS sandbox. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Iterable + +WORKSPACE_BOUNDARY_NOTE = ( + " (this is a hard policy boundary, not a transient failure; " + "do not retry with shell tricks or alternative tools, and ask " + "the user how to proceed if the resource is genuinely required)" +) + + +class WorkspaceBoundaryError(PermissionError): + """Raised when a requested path escapes an allowed workspace boundary.""" + + +def resolve_path(path: str | Path, workspace: str | Path | None = None, *, strict: bool = False) -> Path: + """Resolve *path*, interpreting relative paths against *workspace* when set.""" + candidate = Path(path).expanduser() + if not candidate.is_absolute() and workspace is not None: + candidate = Path(workspace).expanduser() / candidate + return candidate.resolve(strict=strict) + + +def is_path_within(path: str | Path, root: str | Path) -> bool: + """Return True when *path* resolves to *root* or a descendant of *root*.""" + try: + resolved_path = Path(path).expanduser().resolve(strict=False) + resolved_root = Path(root).expanduser().resolve(strict=False) + resolved_path.relative_to(resolved_root) + return True + except (OSError, RuntimeError, TypeError, ValueError): + return False + + +def is_path_allowed(path: str | Path, roots: Iterable[str | Path]) -> bool: + """Return True when *path* is inside any allowed root.""" + return any(is_path_within(path, root) for root in roots) + + +def require_path_within( + path: str | Path, + root: str | Path, + *, + message: str | None = None, +) -> Path: + """Resolve *path* and require it to be inside *root*.""" + resolved = Path(path).expanduser().resolve(strict=False) + if not is_path_within(resolved, root): + raise WorkspaceBoundaryError( + message + or f"Path {path} is outside allowed directory {Path(root).expanduser()}" + + WORKSPACE_BOUNDARY_NOTE + ) + return resolved + + +def resolve_allowed_path( + path: str | Path, + *, + workspace: str | Path | None = None, + allowed_root: str | Path | None = None, + extra_allowed_roots: Iterable[str | Path] | None = None, + strict: bool = False, +) -> Path: + """Resolve a path and enforce containment in allowed roots when configured.""" + resolved = resolve_path(path, workspace, strict=False) + if allowed_root is None: + return resolve_path(path, workspace, strict=strict) if strict else resolved + + roots = [allowed_root, *(extra_allowed_roots or [])] + if not is_path_allowed(resolved, roots): + raise WorkspaceBoundaryError( + f"Path {path} is outside allowed directory {Path(allowed_root).expanduser()}" + + WORKSPACE_BOUNDARY_NOTE + ) + if strict: + return resolve_path(path, workspace, strict=True) + return resolved diff --git a/nanobot/utils/file_edit_events.py b/nanobot/utils/file_edit_events.py index fd929134d..c1885128b 100644 --- a/nanobot/utils/file_edit_events.py +++ b/nanobot/utils/file_edit_events.py @@ -299,6 +299,7 @@ def build_file_edit_end_event( deleted=deleted, approximate=False, binary=(after.binary or after.oversized or after.unreadable) and not counted, + operation="delete" if tracker.before.exists and not after.exists else None, ) @@ -324,6 +325,7 @@ def build_file_edit_live_event( *, added: int, deleted: int = 0, + operation: str | None = None, ) -> dict[str, Any]: """Build an approximate in-progress event while tool-call arguments stream.""" return _event_payload( @@ -333,6 +335,7 @@ def build_file_edit_live_event( added=added, deleted=deleted, approximate=True, + operation=operation, ) @@ -454,15 +457,14 @@ class StreamingFileEditTracker: segment_end = path_matches[i + 1].start() if i + 1 < len(path_matches) else len(state.arguments) segment = state.arguments[segment_start:segment_end] - action_match = re.search(r'"action"\s*:\s*"(replace|add|delete)"', segment) + action_match = re.search(r'"action"\s*:\s*"(replace|add)"', segment) action = action_match.group(1) if action_match else "replace" old_text = _extract_json_string_prefix(segment, "old_text") or "" new_text = _extract_json_string_prefix(segment, "new_text") or "" added = _text_line_count(new_text) if action in ("replace", "add") else 0 - deleted = _text_line_count(old_text) if action in ("replace", "delete") else 0 - delete_file = action == "delete" + deleted = _text_line_count(old_text) if action == "replace" else 0 file_state = state.patch_files.get(raw_path) if file_state is None: @@ -475,8 +477,6 @@ class StreamingFileEditTracker: ) file_state = _StreamingPatchFileState(tracker=tracker) state.patch_files[raw_path] = file_state - if delete_file and added == 0 and deleted == 0 and file_state.tracker.before.countable: - deleted = _text_line_count(file_state.tracker.before.text or "") if not file_state.should_emit(added, deleted, now): continue file_state.mark_emitted(added, deleted, now) @@ -916,6 +916,7 @@ def _event_payload( deleted: int, approximate: bool, binary: bool = False, + operation: str | None = None, ) -> dict[str, Any]: payload: dict[str, Any] = { "version": 1, @@ -931,6 +932,8 @@ def _event_payload( } if binary: payload["binary"] = True + if operation: + payload["operation"] = operation return payload diff --git a/nanobot/webui/mcp_presets_api.py b/nanobot/webui/mcp_presets_api.py index 40a799c1a..6ae4fe828 100644 --- a/nanobot/webui/mcp_presets_api.py +++ b/nanobot/webui/mcp_presets_api.py @@ -124,7 +124,7 @@ MCP_PRESETS: tuple[McpPreset, ...] = ( name="playwright", display_name="Playwright", category="browser", - description="Local browser inspection and automation with the official Playwright MCP server.", + description="Local browser inspection and automation with Playwright's MCP server.", docs_url="https://playwright.dev/docs/getting-started-mcp", transport="stdio", install_supported=True, @@ -216,7 +216,7 @@ MCP_PRESETS: tuple[McpPreset, ...] = ( name="microsoft-learn", display_name="Microsoft Learn", category="docs", - description="Search and fetch official Microsoft Learn documentation through Microsoft's hosted MCP server.", + description="Search and fetch Microsoft Learn documentation through Microsoft's hosted MCP server.", docs_url="https://learn.microsoft.com/en-us/training/support/mcp", transport="streamableHttp", install_supported=True, @@ -307,7 +307,7 @@ MCP_PRESETS: tuple[McpPreset, ...] = ( name="figma", display_name="Figma", category="design", - description="Read design context from Figma using the official local Dev Mode MCP server.", + description="Read design context from Figma using the local Dev Mode MCP server.", docs_url="https://help.figma.com/hc/en-us/articles/32132100833559-Guide-to-the-Figma-MCP-server", transport="streamableHttp", install_supported=True, @@ -325,7 +325,7 @@ MCP_PRESETS: tuple[McpPreset, ...] = ( name="github", display_name="GitHub", category="code", - description="Repository, issue, and pull request workflows via GitHub's official MCP server.", + description="Repository, issue, and pull request workflows via GitHub's MCP server.", docs_url="https://github.com/github/github-mcp-server", transport="stdio", install_supported=True, diff --git a/nanobot/webui/settings_api.py b/nanobot/webui/settings_api.py index 7e093a5e2..efd836b2e 100644 --- a/nanobot/webui/settings_api.py +++ b/nanobot/webui/settings_api.py @@ -7,7 +7,9 @@ settings payload shape and the allowlisted config mutations exposed to WebUI. from __future__ import annotations import re -from typing import Any +import time +from contextlib import suppress +from typing import Any, Literal from zoneinfo import ZoneInfo from nanobot.config.loader import get_config_path, load_config, save_config @@ -17,8 +19,48 @@ from nanobot.providers.image_generation import ( image_gen_provider_names, ) from nanobot.providers.registry import PROVIDERS, find_by_name +from nanobot.security.workspace_access import workspace_sandbox_status +from nanobot.webui.workspaces import ( + read_webui_default_access_mode, + write_webui_default_access_mode, +) QueryParams = dict[str, list[str]] +RuntimeSurface = Literal["browser", "native"] + +_RUNTIME_CAPABILITIES = { + "can_restart_engine": False, + "can_pick_folder": False, + "can_open_logs": False, + "can_export_diagnostics": False, +} + +_NATIVE_RUNTIME_CAPABILITIES = { + **_RUNTIME_CAPABILITIES, + "can_restart_engine": True, + "can_pick_folder": True, + "can_open_logs": True, + "can_export_diagnostics": True, +} + +_BROWSER_RESTART_BEHAVIOR_BY_SECTION = { + "appearance": "none", + "models": "none", + "providers": "none", + "runtime": "engineRestart", + "browser": "engineRestart", + "image": "engineRestart", + "apps": "engineRestart", + "advanced": "appRestart", +} + +_NATIVE_RESTART_BEHAVIOR_BY_SECTION = { + **_BROWSER_RESTART_BEHAVIOR_BY_SECTION, + "runtime": "engineRestart", + "browser": "engineRestart", + "image": "engineRestart", + "apps": "engineRestart", +} _WEB_SEARCH_PROVIDER_OPTIONS: tuple[dict[str, str], ...] = ( {"name": "duckduckgo", "label": "DuckDuckGo", "credential": "none"}, @@ -55,6 +97,70 @@ class WebUISettingsError(ValueError): self.status = status +def _normalize_surface(surface: str | None) -> RuntimeSurface: + return "native" if surface in {"native", "desktop"} else "browser" + + +def runtime_capabilities( + surface: str | None = "browser", + overrides: dict[str, Any] | None = None, +) -> dict[str, bool]: + """Return the capability flags exposed to the WebUI runtime.""" + base = ( + _NATIVE_RUNTIME_CAPABILITIES + if _normalize_surface(surface) == "native" + else _RUNTIME_CAPABILITIES + ) + result = dict(base) + for key, value in (overrides or {}).items(): + if key in result: + result[key] = bool(value) + return result + + +def restart_behavior_by_section(surface: str | None = "browser") -> dict[str, str]: + return dict( + _NATIVE_RESTART_BEHAVIOR_BY_SECTION + if _normalize_surface(surface) == "native" + else _BROWSER_RESTART_BEHAVIOR_BY_SECTION + ) + + +def decorate_settings_payload( + payload: dict[str, Any], + *, + surface: str | None = "browser", + runtime_capability_overrides: dict[str, Any] | None = None, + restart_required_sections: list[str] | None = None, + apply_state: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Attach runtime-surface metadata without changing the core settings shape.""" + surface_value = _normalize_surface(surface) + sections = restart_required_sections + if sections is None: + raw_sections = payload.get("restart_required_sections") or [] + sections = [str(section) for section in raw_sections if isinstance(section, str)] + sections = sorted(dict.fromkeys(sections)) + result = dict(payload) + result["surface"] = surface_value + result["runtime_surface"] = surface_value + result["runtime_capabilities"] = runtime_capabilities( + surface_value, + runtime_capability_overrides, + ) + result["restart_behavior_by_section"] = restart_behavior_by_section(surface_value) + result["restart_required_sections"] = sections + if sections: + result["requires_restart"] = True + else: + result["requires_restart"] = bool(result.get("requires_restart", False)) + result["apply_state"] = apply_state or { + "status": "pending" if result["requires_restart"] else "idle", + "sections": sections, + } + return result + + def _query_first(query: QueryParams, key: str) -> str | None: values = query.get(key) return values[0] if values else None @@ -83,9 +189,57 @@ def _provider_requires_api_key(spec: Any) -> bool: return True +def _oauth_provider_status(spec: Any) -> dict[str, Any]: + if not getattr(spec, "is_oauth", False): + return {"configured": False, "account": None, "expires_at": None, "login_supported": False} + + if spec.name == "openai_codex": + try: + from oauth_cli_kit import get_token as get_codex_token + except Exception: + return { + "configured": False, + "account": None, + "expires_at": None, + "login_supported": False, + } + token = None + with suppress(Exception): + token = get_codex_token() + expires_at = getattr(token, "expires", None) if token else None + return { + "configured": bool(token and token.access), + "account": getattr(token, "account_id", None) if token else None, + "expires_at": expires_at, + "login_supported": True, + } + + if spec.name == "github_copilot": + try: + from nanobot.providers.github_copilot_provider import get_github_copilot_login_status + except Exception: + return { + "configured": False, + "account": None, + "expires_at": None, + "login_supported": False, + } + token = None + with suppress(Exception): + token = get_github_copilot_login_status() + return { + "configured": bool(token and token.access and token.expires > int(time.time() * 1000)), + "account": getattr(token, "account_id", None) if token else None, + "expires_at": getattr(token, "expires", None) if token else None, + "login_supported": True, + } + + return {"configured": False, "account": None, "expires_at": None, "login_supported": False} + + def _provider_configured_for_settings(spec: Any, provider_config: Any) -> bool: if spec.is_oauth: - return True + return bool(_oauth_provider_status(spec)["configured"]) if _provider_requires_api_key(spec): return bool(provider_config.api_key) return bool( @@ -144,6 +298,7 @@ def _image_generation_provider_rows(config: Any) -> list[dict[str, Any]]: "name": name, "label": spec.label if spec is not None else name, "configured": configured, + "auth_type": "oauth" if spec is not None and spec.is_oauth else "api_key", "api_key_hint": _mask_secret_hint( getattr(provider_config, "api_key", None) ), @@ -156,7 +311,14 @@ def _image_generation_provider_rows(config: Any) -> list[dict[str, Any]]: return rows -def settings_payload(*, requires_restart: bool = False) -> dict[str, Any]: +def settings_payload( + *, + requires_restart: bool = False, + surface: str | None = "browser", + runtime_capability_overrides: dict[str, Any] | None = None, + restart_required_sections: list[str] | None = None, + apply_state: dict[str, Any] | None = None, +) -> dict[str, Any]: config = load_config() defaults = config.agents.defaults active_preset_name = defaults.model_preset or "default" @@ -179,17 +341,27 @@ def settings_payload(*, requires_restart: bool = False) -> dict[str, Any]: providers = [] for spec in PROVIDERS: provider_config = getattr(config.providers, spec.name, None) - if provider_config is None or spec.is_oauth: + if provider_config is None: continue + oauth_status = _oauth_provider_status(spec) if spec.is_oauth else None row = { "name": spec.name, "label": spec.label, - "configured": _provider_configured_for_settings(spec, provider_config), + "configured": ( + bool(oauth_status["configured"]) + if oauth_status is not None + else _provider_configured_for_settings(spec, provider_config) + ), + "auth_type": "oauth" if spec.is_oauth else "api_key", "api_key_required": _provider_requires_api_key(spec), "api_key_hint": _mask_secret_hint(provider_config.api_key), "api_base": provider_config.api_base, "default_api_base": spec.default_api_base or None, } + if oauth_status is not None: + row["oauth_account"] = oauth_status["account"] + row["oauth_expires_at"] = oauth_status["expires_at"] + row["oauth_login_supported"] = oauth_status["login_supported"] if spec.name == "openai": row["api_type"] = provider_config.api_type providers.append(row) @@ -241,7 +413,11 @@ def settings_payload(*, requires_restart: bool = False) -> dict[str, Any]: ) exec_config = config.tools.exec - return { + sandbox_status = workspace_sandbox_status( + restrict_to_workspace=config.tools.restrict_to_workspace, + workspace=config.workspace_path, + ) + payload = { "agent": { "model": effective_preset.model, "provider": selected_provider, @@ -312,6 +488,11 @@ def settings_payload(*, requires_restart: bool = False) -> dict[str, Any]: }, "advanced": { "restrict_to_workspace": config.tools.restrict_to_workspace, + "workspace_sandbox": sandbox_status.as_dict(), + "webui_allow_local_service_access": config.tools.webui_allow_local_service_access, + "allow_local_preview_access": config.tools.webui_allow_local_service_access, + "webui_default_access_mode": read_webui_default_access_mode(), + "private_service_protection_enabled": True, "ssrf_whitelist_count": len(config.tools.ssrf_whitelist), "mcp_server_count": len(config.tools.mcp_servers), "exec_enabled": exec_config.enable, @@ -320,6 +501,13 @@ def settings_payload(*, requires_restart: bool = False) -> dict[str, Any]: }, "requires_restart": requires_restart, } + return decorate_settings_payload( + payload, + surface=surface, + runtime_capability_overrides=runtime_capability_overrides, + restart_required_sections=restart_required_sections, + apply_state=apply_state, + ) def update_agent_settings(query: QueryParams) -> dict[str, Any]: @@ -444,6 +632,54 @@ def create_model_configuration(query: QueryParams) -> dict[str, Any]: return settings_payload() +def update_model_configuration(query: QueryParams) -> dict[str, Any]: + name = (_query_first(query, "name") or "").strip() + if not name or name == "default": + raise WebUISettingsError("model configuration is required") + + config = load_config() + preset = config.model_presets.get(name) + if preset is None: + raise WebUISettingsError("unknown model configuration") + + changed = False + label = _query_first_alias(query, "label", "displayName") + if label is not None: + label = label.strip() + if not label: + raise WebUISettingsError("label is required") + if preset.label != label: + preset.label = label + changed = True + + model = _query_first(query, "model") + if model is not None: + model = model.strip() + if not model: + raise WebUISettingsError("model is required") + if preset.model != model: + preset.model = model + changed = True + + provider = _query_first(query, "provider") + if provider is not None: + provider = provider.strip() + if not provider: + raise WebUISettingsError("provider is required") + _validate_configured_provider(config, provider) + if preset.provider != provider: + preset.provider = provider + changed = True + + if config.agents.defaults.model_preset != name: + config.agents.defaults.model_preset = name + changed = True + + if changed: + save_config(config) + return settings_payload() + + def update_provider_settings(query: QueryParams) -> dict[str, Any]: provider_name = (_query_first(query, "provider") or "").strip() if not provider_name: @@ -495,6 +731,114 @@ def update_provider_settings(query: QueryParams) -> dict[str, Any]: return settings_payload(requires_restart=restart_required) +def login_oauth_provider(query: QueryParams) -> dict[str, Any]: + provider_name = (_query_first(query, "provider") or "").strip() + if not provider_name: + raise WebUISettingsError("provider is required") + spec = find_by_name(provider_name) + if spec is None or not spec.is_oauth: + raise WebUISettingsError("unknown OAuth provider") + + if spec.name == "openai_codex": + try: + from oauth_cli_kit import get_token, login_oauth_interactive + except ImportError: + raise WebUISettingsError("oauth_cli_kit is not installed", status=500) from None + + token = None + with suppress(Exception): + token = get_token() + if not (token and token.access): + messages: list[str] = [] + token = login_oauth_interactive( + print_fn=lambda message: messages.append(str(message)), + prompt_fn=lambda _prompt: "", + ) + if not (token and token.access): + raise WebUISettingsError("OAuth login failed", status=401) + return settings_payload() + + if spec.name == "github_copilot": + try: + from nanobot.providers.github_copilot_provider import ( + get_github_copilot_login_status, + login_github_copilot, + ) + except ImportError: + raise WebUISettingsError("GitHub Copilot OAuth support is unavailable", status=500) from None + + token = get_github_copilot_login_status() + if not token: + token = login_github_copilot(print_fn=lambda _message: None) + if not (token and token.access): + raise WebUISettingsError("OAuth login failed", status=401) + return settings_payload() + + raise WebUISettingsError("OAuth login is not supported for this provider") + + +def logout_oauth_provider(query: QueryParams) -> dict[str, Any]: + provider_name = (_query_first(query, "provider") or "").strip() + if not provider_name: + raise WebUISettingsError("provider is required") + spec = find_by_name(provider_name) + if spec is None or not spec.is_oauth: + raise WebUISettingsError("unknown OAuth provider") + + if spec.name == "openai_codex": + try: + from oauth_cli_kit.providers import OPENAI_CODEX_PROVIDER + from oauth_cli_kit.storage import FileTokenStorage + except ImportError: + raise WebUISettingsError("oauth_cli_kit is not installed", status=500) from None + token_path = FileTokenStorage(token_filename=OPENAI_CODEX_PROVIDER.token_filename).get_token_path() + elif spec.name == "github_copilot": + try: + from nanobot.providers.github_copilot_provider import get_storage + except ImportError: + raise WebUISettingsError("GitHub Copilot OAuth support is unavailable", status=500) from None + token_path = get_storage().get_token_path() + else: + raise WebUISettingsError("OAuth logout is not supported for this provider") + + for path in (token_path, token_path.with_suffix(".lock")): + with suppress(FileNotFoundError): + path.unlink() + return settings_payload() + + +def update_network_safety_settings(query: QueryParams) -> dict[str, Any]: + raw_allow = ( + _query_first_alias(query, "webui_allow_local_service_access", "webuiAllowLocalServiceAccess") + or _query_first_alias(query, "allow_local_preview_access", "allowLocalPreviewAccess") + ) + raw_default_access_mode = _query_first_alias(query, "webui_default_access_mode", "webuiDefaultAccessMode") + if raw_allow is None and raw_default_access_mode is None: + raise WebUISettingsError("webui_allow_local_service_access or webui_default_access_mode is required") + + config = load_config() + changed = False + if raw_allow is not None: + webui_allow_local_service_access = _parse_bool(raw_allow, "webui_allow_local_service_access") + if config.tools.webui_allow_local_service_access != webui_allow_local_service_access: + config.tools.webui_allow_local_service_access = webui_allow_local_service_access + changed = True + + if changed: + save_config(config) + if raw_default_access_mode is not None: + default_access_mode = raw_default_access_mode.strip().lower() + if default_access_mode == "restricted": + default_access_mode = "default" + if default_access_mode not in {"default", "full"}: + raise WebUISettingsError("webui_default_access_mode must be default or full") + try: + write_webui_default_access_mode(default_access_mode) + except ValueError as exc: + raise WebUISettingsError(str(exc)) from exc + return settings_payload(requires_restart=changed) + + def update_web_search_settings(query: QueryParams) -> dict[str, Any]: provider_name = (_query_first(query, "provider") or "").strip().lower() provider_option = _WEB_SEARCH_PROVIDER_BY_NAME.get(provider_name) diff --git a/nanobot/webui/sidebar_state.py b/nanobot/webui/sidebar_state.py index 12d26c106..0a2f4cfcc 100644 --- a/nanobot/webui/sidebar_state.py +++ b/nanobot/webui/sidebar_state.py @@ -38,6 +38,7 @@ def default_webui_sidebar_state() -> dict[str, Any]: "pinned_keys": [], "archived_keys": [], "title_overrides": {}, + "project_name_overrides": {}, "tags_by_key": {}, "collapsed_groups": {}, "view": { @@ -136,6 +137,9 @@ def normalize_webui_sidebar_state(raw: Any) -> dict[str, Any]: state["pinned_keys"] = _clean_string_list(raw.get("pinned_keys")) state["archived_keys"] = _clean_string_list(raw.get("archived_keys")) state["title_overrides"] = _clean_title_overrides(raw.get("title_overrides")) + state["project_name_overrides"] = _clean_title_overrides( + raw.get("project_name_overrides") + ) state["tags_by_key"] = _clean_tags_by_key(raw.get("tags_by_key")) state["collapsed_groups"] = _clean_bool_map(raw.get("collapsed_groups")) state["view"] = _clean_view(raw.get("view")) @@ -190,4 +194,3 @@ def write_webui_sidebar_state(raw: dict[str, Any]) -> dict[str, Any]: finally: os.close(dir_fd) return state - diff --git a/nanobot/webui/transcript.py b/nanobot/webui/transcript.py index be525ac93..69ef4e471 100644 --- a/nanobot/webui/transcript.py +++ b/nanobot/webui/transcript.py @@ -28,6 +28,11 @@ _INLINE_MARKDOWN_IMAGE_EXTS: frozenset[str] = frozenset({ ".webp", ".gif", }) +_FILE_EDIT_TOOL_NAMES: frozenset[str] = frozenset({ + "write_file", + "edit_file", + "apply_patch", +}) def rewrite_local_markdown_images( @@ -200,6 +205,19 @@ def _tool_event_key(event: dict[str, Any]) -> str: return _format_tool_call_trace(event) or json.dumps(event, sort_keys=True, ensure_ascii=False) +def _tool_event_file_edit_key(event: dict[str, Any]) -> str | None: + call_id = event.get("call_id") + if not isinstance(call_id, str) or not call_id: + return None + name = event.get("name") + if not isinstance(name, str) or not name: + fn = event.get("function") + name = fn.get("name") if isinstance(fn, dict) else "" + if not isinstance(name, str) or name not in _FILE_EDIT_TOOL_NAMES: + return None + return f"{call_id}|{name}" + + def _merge_tool_events(previous: Any, incoming: list[dict[str, Any]]) -> list[dict[str, Any]]: if not isinstance(previous, list) or not previous: return incoming @@ -222,6 +240,87 @@ def _merge_tool_events(previous: Any, incoming: list[dict[str, Any]]) -> list[di return merged +def _file_edit_key(edit: dict[str, Any]) -> str: + call_id = str(edit.get("call_id") or "") + tool = str(edit.get("tool") or "") + if call_id: + return f"{call_id}|{tool}" + return f"{tool}|{edit.get('path') or ''}" + + +def _message_has_file_edit_for_tool_event( + message: dict[str, Any], + event: dict[str, Any], +) -> bool: + key = _tool_event_file_edit_key(event) + if not key: + return False + edits = message.get("fileEdits") + if not isinstance(edits, list): + return False + return any(isinstance(edit, dict) and _file_edit_key(edit) == key for edit in edits) + + +def _filter_covered_file_edit_tool_events( + messages: list[dict[str, Any]], + events: list[dict[str, Any]], +) -> list[dict[str, Any]]: + if not events: + return events + return [ + event + for event in events + if not any(_message_has_file_edit_for_tool_event(message, event) for message in messages) + ] + + +def _strip_covered_file_edit_tool_hints( + message: dict[str, Any], + edits: list[dict[str, Any]], +) -> dict[str, Any]: + incoming_keys = { + _file_edit_key(edit) + for edit in edits + if isinstance(edit, dict) + } + events = message.get("toolEvents") + if not incoming_keys or not isinstance(events, list): + return message + + kept_events: list[dict[str, Any]] = [] + removed_trace_lines: set[str] = set() + changed = False + for event in events: + if not isinstance(event, dict): + continue + key = _tool_event_file_edit_key(event) + if key and key in incoming_keys: + changed = True + removed_trace_lines.update(tool_trace_lines_from_events([event])) + continue + kept_events.append(event) + if not changed: + return message + + raw_traces = message.get("traces") + if isinstance(raw_traces, list): + previous_traces = [trace for trace in raw_traces if isinstance(trace, str)] + else: + content = message.get("content") + previous_traces = [content] if isinstance(content, str) and content else [] + next_traces = [trace for trace in previous_traces if trace not in removed_trace_lines] + next_message = { + **message, + "traces": next_traces, + "content": next_traces[-1] if next_traces else "", + } + if kept_events: + next_message["toolEvents"] = kept_events + else: + next_message.pop("toolEvents", None) + return next_message + + def _merge_unique_tool_trace_lines( previous_traces: list[str], lines: list[str], @@ -343,6 +442,40 @@ def replay_transcript_to_ui_messages( return None return str(last.get("id")) + def demote_interrupted_assistant(segment: str) -> None: + nonlocal buffer_message_id, buffer_parts + for i in range(len(messages) - 1, -1, -1): + candidate = messages[i] + if candidate.get("role") == "user": + break + content = candidate.get("content") + if ( + candidate.get("role") != "assistant" + or candidate.get("kind") == "trace" + or not candidate.get("isStreaming") + or not isinstance(content, str) + or not content.strip() + or candidate.get("media") + ): + continue + reasoning_parts = [ + part + for part in (candidate.get("reasoning"), content) + if isinstance(part, str) and part.strip() + ] + messages[i] = { + **candidate, + "content": "", + "reasoning": "\n\n".join(reasoning_parts), + "reasoningStreaming": False, + "isStreaming": False, + "activitySegmentId": candidate.get("activitySegmentId") or segment, + } + if buffer_message_id == candidate.get("id"): + buffer_message_id = None + buffer_parts = [] + return + def close_reasoning(prev: list[dict[str, Any]]) -> None: for i in range(len(prev) - 1, -1, -1): if prev[i].get("reasoningStreaming"): @@ -404,13 +537,6 @@ def replay_transcript_to_ui_messages( active_activity_segment_id = None active_file_edit_segment_id = None - def _file_edit_key(edit: dict[str, Any]) -> str: - call_id = str(edit.get("call_id") or "") - tool = str(edit.get("tool") or "") - if call_id: - return f"{call_id}|{tool}" - return f"{tool}|{edit.get('path') or ''}" - def find_file_edit_trace_index( segment: str | None, edits: list[dict[str, Any]], @@ -420,16 +546,23 @@ def replay_transcript_to_ui_messages( candidate = messages[i] if candidate.get("role") == "user": break - if candidate.get("kind") != "trace" or not candidate.get("fileEdits"): + if candidate.get("kind") != "trace": continue if segment and candidate.get("activitySegmentId") == segment: return i existing_edits = candidate.get("fileEdits") - if not isinstance(existing_edits, list): - continue - for existing in existing_edits: - if isinstance(existing, dict) and _file_edit_key(existing) in incoming_keys: - return i + if isinstance(existing_edits, list): + for existing in existing_edits: + if isinstance(existing, dict) and _file_edit_key(existing) in incoming_keys: + return i + existing_tool_events = candidate.get("toolEvents") + if isinstance(existing_tool_events, list): + for event in existing_tool_events: + if not isinstance(event, dict): + continue + key = _tool_event_file_edit_key(event) + if key and key in incoming_keys: + return i return None def upsert_file_edits(edits: list[dict[str, Any]], idx: int) -> None: @@ -437,11 +570,16 @@ def replay_transcript_to_ui_messages( if not edits: return segment = active_file_edit_segment_id + if not segment: + segment = _new_activity_segment(activate=False) + active_file_edit_segment_id = segment + demote_interrupted_assistant(segment) target_index = find_file_edit_trace_index(segment, edits) if target_index is not None: last = messages[target_index] segment = str(last.get("activitySegmentId") or segment or _new_activity_segment(activate=False)) active_file_edit_segment_id = segment + last = _strip_covered_file_edit_tool_hints(last, edits) else: if not segment: segment = _new_activity_segment(activate=False) @@ -620,12 +758,21 @@ def replay_transcript_to_ui_messages( continue if kind in ("tool_hint", "progress"): structured_events = _normalize_tool_events(rec.get("tool_events")) - structured = tool_trace_lines_from_events(rec.get("tool_events")) + visible_structured_events = _filter_covered_file_edit_tool_events(messages, structured_events) + structured = tool_trace_lines_from_events(visible_structured_events) text = rec.get("text") - trace_lines = structured if structured else ([text] if isinstance(text, str) and text else []) + if structured: + trace_lines = structured + elif structured_events: + trace_lines = [] + elif isinstance(text, str) and text: + trace_lines = [text] + else: + trace_lines = [] if not trace_lines: continue segment = _ensure_activity_segment() + demote_interrupted_assistant(segment) last = messages[-1] if messages else None if ( last @@ -636,7 +783,7 @@ def replay_transcript_to_ui_messages( prev_traces = list(last.get("traces") or [last.get("content")]) if structured: merged_traces, added = _merge_unique_tool_trace_lines(prev_traces, structured) - if not added and not structured_events: + if not added and not visible_structured_events: continue else: merged_traces = prev_traces + trace_lines @@ -644,8 +791,8 @@ def replay_transcript_to_ui_messages( **last, "traces": merged_traces, "content": merged_traces[-1], - "toolEvents": _merge_tool_events(last.get("toolEvents"), structured_events) - if structured_events + "toolEvents": _merge_tool_events(last.get("toolEvents"), visible_structured_events) + if visible_structured_events else last.get("toolEvents"), "activitySegmentId": last.get("activitySegmentId") or segment, } @@ -658,7 +805,7 @@ def replay_transcript_to_ui_messages( "kind": "trace", "content": trace_lines[-1], "traces": trace_lines, - **({"toolEvents": structured_events} if structured_events else {}), + **({"toolEvents": visible_structured_events} if visible_structured_events else {}), "activitySegmentId": segment, "createdAt": _ts_base + idx, }, diff --git a/nanobot/webui/workspaces.py b/nanobot/webui/workspaces.py new file mode 100644 index 000000000..774f2857f --- /dev/null +++ b/nanobot/webui/workspaces.py @@ -0,0 +1,283 @@ +"""Persisted WebUI project workspace state.""" + +from __future__ import annotations + +import json +import os +import time +from pathlib import Path +from typing import Any + +from loguru import logger + +from nanobot.config.paths import get_webui_dir +from nanobot.security.workspace_access import ( + WORKSPACE_SCOPE_METADATA_KEY, + WorkspaceScope, + WorkspaceScopeError, + build_workspace_scope, + default_workspace_scope, + validate_workspace_scope_payload, +) + +WEBUI_WORKSPACE_STATE_SCHEMA_VERSION = 1 +_MAX_STATE_FILE_BYTES = 128 * 1024 +_DEFAULT_ACCESS_MODES = {"default", "full"} +_LEGACY_RESTRICTED_DEFAULT_ACCESS_MODE = "restricted" +_WEBUI_SCOPE_CHANNEL = "websocket" + + +def webui_workspace_state_path() -> Path: + return get_webui_dir() / "workspace-state.json" + + +def default_webui_workspace_state() -> dict[str, Any]: + return { + "schema_version": WEBUI_WORKSPACE_STATE_SCHEMA_VERSION, + "default_access_mode": "default", + "updated_at": None, + } + + +def normalize_webui_workspace_state(raw: Any) -> dict[str, Any]: + if not isinstance(raw, dict): + raw = {} + state = default_webui_workspace_state() + updated_at = raw.get("updated_at") + state["updated_at"] = updated_at if isinstance(updated_at, str) else None + default_access_mode = raw.get("default_access_mode") + if default_access_mode in _DEFAULT_ACCESS_MODES: + state["default_access_mode"] = default_access_mode + return state + + +def read_webui_workspace_state() -> dict[str, Any]: + path = webui_workspace_state_path() + if not path.is_file(): + return default_webui_workspace_state() + try: + if path.stat().st_size > _MAX_STATE_FILE_BYTES: + logger.warning("webui workspace state too large, ignoring: {}", path) + return default_webui_workspace_state() + with open(path, encoding="utf-8") as f: + raw = json.load(f) + except (OSError, json.JSONDecodeError) as e: + logger.warning("read webui workspace state failed {}: {}", path, e) + return default_webui_workspace_state() + return normalize_webui_workspace_state(raw) + + +def write_webui_workspace_state(raw: dict[str, Any]) -> dict[str, Any]: + state = normalize_webui_workspace_state(raw) + state["updated_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + encoded = json.dumps( + state, + ensure_ascii=False, + indent=2, + sort_keys=True, + ).encode("utf-8") + if len(encoded) > _MAX_STATE_FILE_BYTES: + raise ValueError("workspace state is too large") + + path = webui_workspace_state_path() + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(".json.tmp") + with open(tmp, "wb") as f: + f.write(encoded) + f.write(b"\n") + f.flush() + os.fsync(f.fileno()) + os.replace(tmp, path) + try: + dir_fd = os.open(path.parent, os.O_RDONLY) + except OSError: + return state + try: + os.fsync(dir_fd) + finally: + os.close(dir_fd) + return state + + +def read_webui_default_access_mode() -> str: + state = read_webui_workspace_state() + mode = state.get("default_access_mode") + return mode if mode in _DEFAULT_ACCESS_MODES else "default" + + +def write_webui_default_access_mode(mode: str) -> bool: + if mode == _LEGACY_RESTRICTED_DEFAULT_ACCESS_MODE: + mode = "default" + if mode not in _DEFAULT_ACCESS_MODES: + raise ValueError("default access mode must be default or full") + state = read_webui_workspace_state() + changed = state.get("default_access_mode") != mode + if changed: + state["default_access_mode"] = mode + write_webui_workspace_state(state) + return changed + + +def default_scope_for_webui( + default_workspace: Path, + default_restrict_to_workspace: bool, +) -> WorkspaceScope: + mode = read_webui_default_access_mode() + if mode == "default": + return default_workspace_scope( + default_workspace, + default_restrict_to_workspace, + source_channel=_WEBUI_SCOPE_CHANNEL, + ) + return build_workspace_scope(default_workspace, mode, source_channel=_WEBUI_SCOPE_CHANNEL) + + +def workspaces_payload( + *, + default_workspace: Path, + default_restrict_to_workspace: bool, + controls_available: bool, +) -> dict[str, Any]: + default_access_mode = read_webui_default_access_mode() + default_scope = ( + default_workspace_scope( + default_workspace, + default_restrict_to_workspace, + source_channel=_WEBUI_SCOPE_CHANNEL, + ) + if default_access_mode == "default" + else build_workspace_scope(default_workspace, default_access_mode, source_channel=_WEBUI_SCOPE_CHANNEL) + ) + return { + "schema_version": WEBUI_WORKSPACE_STATE_SCHEMA_VERSION, + "default_access_mode": default_access_mode, + "default_scope": default_scope.payload(), + "controls": { + "can_change_project": controls_available, + "can_use_full_access": controls_available, + }, + } + + +class WebUIWorkspaceController: + """Own WebUI project scope persistence and validation.""" + + def __init__( + self, + *, + session_manager: Any | None, + default_workspace: Path, + default_restrict_to_workspace: bool, + ) -> None: + self._sessions = session_manager + self._default_workspace = default_workspace + self._default_restrict_to_workspace = default_restrict_to_workspace + + def default_scope(self) -> WorkspaceScope: + return default_scope_for_webui( + self._default_workspace, + self._default_restrict_to_workspace, + ) + + def scope_for_session_key(self, session_key: str) -> WorkspaceScope: + if self._sessions is None: + return self.default_scope() + data = self._sessions.read_session_file(session_key) + metadata = data.get("metadata", {}) if isinstance(data, dict) else {} + if not isinstance(metadata, dict) or WORKSPACE_SCOPE_METADATA_KEY not in metadata: + return self.default_scope() + try: + return validate_workspace_scope_payload( + metadata.get(WORKSPACE_SCOPE_METADATA_KEY), + default_workspace=self._default_workspace, + default_restrict_to_workspace=self._default_restrict_to_workspace, + source_channel=_WEBUI_SCOPE_CHANNEL, + ) + except WorkspaceScopeError: + return self.default_scope() + + def payload(self, *, controls_available: bool) -> dict[str, Any]: + return workspaces_payload( + default_workspace=self._default_workspace, + default_restrict_to_workspace=self._default_restrict_to_workspace, + controls_available=controls_available, + ) + + def scope_from_envelope( + self, + envelope: dict[str, Any], + *, + session_key: str | None, + controls_available: bool, + ) -> WorkspaceScope: + raw = envelope.get(WORKSPACE_SCOPE_METADATA_KEY) + if raw is None and session_key: + scope = self.scope_for_session_key(session_key) + elif raw is None: + scope = self.default_scope() + else: + scope = validate_workspace_scope_payload( + raw, + default_workspace=self._default_workspace, + default_restrict_to_workspace=self._default_restrict_to_workspace, + source_channel=_WEBUI_SCOPE_CHANNEL, + ) + if not controls_available and scope.metadata() != self.default_scope().metadata(): + raise WorkspaceScopeError("workspace controls are localhost-only", status=403) + return scope + + def scope_for_new_chat( + self, + envelope: dict[str, Any], + *, + controls_available: bool, + ) -> WorkspaceScope: + return self.scope_from_envelope( + envelope, + session_key=None, + controls_available=controls_available, + ) + + def scope_for_set_request( + self, + envelope: dict[str, Any], + *, + chat_id: str, + chat_running: bool, + controls_available: bool, + ) -> WorkspaceScope: + if chat_running: + raise WorkspaceScopeError("chat_running", status=409) + return self.scope_from_envelope( + envelope, + session_key=f"websocket:{chat_id}", + controls_available=controls_available, + ) + + def scope_for_message( + self, + envelope: dict[str, Any], + *, + chat_id: str, + chat_running: bool, + controls_available: bool, + ) -> WorkspaceScope: + scope = self.scope_from_envelope( + envelope, + session_key=f"websocket:{chat_id}", + controls_available=controls_available, + ) + if ( + WORKSPACE_SCOPE_METADATA_KEY in envelope + and chat_running + and scope.metadata() != self.scope_for_session_key(f"websocket:{chat_id}").metadata() + ): + raise WorkspaceScopeError("chat_running", status=409) + return scope + + def persist_scope(self, chat_id: str, scope: WorkspaceScope) -> None: + if self._sessions is not None: + session = self._sessions.get_or_create(f"websocket:{chat_id}") + session.metadata["webui"] = True + session.metadata[WORKSPACE_SCOPE_METADATA_KEY] = scope.metadata() + self._sessions.save(session) diff --git a/tests/agent/test_loop_direct_websocket_status.py b/tests/agent/test_loop_direct_websocket_status.py new file mode 100644 index 000000000..2c8581e2a --- /dev/null +++ b/tests/agent/test_loop_direct_websocket_status.py @@ -0,0 +1,55 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.agent.loop import AgentLoop +from nanobot.bus.queue import MessageBus +from nanobot.providers.base import GenerationSettings, LLMResponse + + +def _make_loop(tmp_path): + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.generation = GenerationSettings(max_tokens=0) + provider.estimate_prompt_tokens.return_value = (0, "test-counter") + response = LLMResponse(content="done", tool_calls=[]) + provider.chat_with_retry = AsyncMock(return_value=response) + provider.chat_stream_with_retry = AsyncMock(return_value=response) + + loop = AgentLoop( + bus=bus, + provider=provider, + workspace=tmp_path, + model="test-model", + ) + loop.tools.get_definitions = MagicMock(return_value=[]) + return loop + + +@pytest.mark.asyncio +async def test_process_direct_websocket_clears_run_status(tmp_path) -> None: + loop = _make_loop(tmp_path) + + response = await loop.process_direct( + "deliver reminder", + session_key="cron:reminder-1", + channel="websocket", + chat_id="chat-1", + ) + + assert response is not None + assert response.content == "done" + + events = [] + while loop.bus.outbound_size: + events.append(await loop.bus.consume_outbound()) + + statuses = [ + event.metadata + for event in events + if event.metadata.get("_goal_status") is True + ] + assert [status["goal_status"] for status in statuses] == ["running", "idle"] + assert isinstance(statuses[0].get("started_at"), float) + assert "started_at" not in statuses[1] diff --git a/tests/agent/test_workspace_scope.py b/tests/agent/test_workspace_scope.py new file mode 100644 index 000000000..504392436 --- /dev/null +++ b/tests/agent/test_workspace_scope.py @@ -0,0 +1,344 @@ +import json +import time +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from nanobot.agent.tools.cli_apps import CliAppsTool +from nanobot.agent.tools.filesystem import ReadFileTool +from nanobot.agent.tools.image_generation import ImageGenerationError, ImageGenerationTool +from nanobot.agent.tools.message import MessageTool +from nanobot.agent.tools.shell import ExecTool +from nanobot.agent.tools.spawn import SpawnTool +from nanobot.security.workspace_access import ( + WORKSPACE_SCOPE_METADATA_KEY, + WorkspaceScopeError, + bind_workspace_scope, + default_workspace_scope, + reset_workspace_scope, + validate_workspace_scope_payload, + workspace_scope_from_metadata, +) +from nanobot.apps.cli.service import CliAppManager, CliAppsRuntimeConfig +from nanobot.config.schema import ImageGenerationToolConfig, ProviderConfig + +PNG_BYTES = ( + b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01" + b"\x00\x00\x00\x01\x08\x04\x00\x00\x00\xb5\x1c\x0c\x02" + b"\x00\x00\x00\x0bIDATx\xdacd\xfc\xff\x1f\x00\x03\x03" + b"\x02\x00\xef\xbf\xa7\xdb\x00\x00\x00\x00IEND\xaeB`\x82" +) + + +def test_workspace_scope_defaults_match_legacy_config(tmp_path: Path) -> None: + unrestricted = default_workspace_scope(tmp_path, restrict_to_workspace=False) + restricted = default_workspace_scope(tmp_path, restrict_to_workspace=True) + + assert unrestricted.project_path == tmp_path.resolve() + assert unrestricted.access_mode == "full" + assert unrestricted.restrict_to_workspace is False + assert restricted.access_mode == "restricted" + assert restricted.restrict_to_workspace is True + + +def test_workspace_scope_rejects_invalid_project_path(tmp_path: Path) -> None: + with pytest.raises(WorkspaceScopeError, match="absolute"): + validate_workspace_scope_payload( + {"project_path": "relative/project", "access_mode": "restricted"}, + default_workspace=tmp_path, + default_restrict_to_workspace=False, + ) + + with pytest.raises(WorkspaceScopeError, match="existing directory"): + validate_workspace_scope_payload( + {"project_path": str(tmp_path / "missing"), "access_mode": "restricted"}, + default_workspace=tmp_path, + default_restrict_to_workspace=False, + ) + + +def test_workspace_scope_accepts_home_relative_project_path( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + home = tmp_path / "home" + project = home / "Desktop" / "Photos" + project.mkdir(parents=True) + monkeypatch.setenv("HOME", str(home)) + monkeypatch.setenv("USERPROFILE", str(home)) + + scope = validate_workspace_scope_payload( + {"project_path": "~/Desktop/Photos", "access_mode": "restricted"}, + default_workspace=tmp_path, + default_restrict_to_workspace=False, + ) + + assert scope.project_path == project.resolve() + assert scope.metadata()["project_path"] == str(project.resolve()) + + +def test_workspace_scope_metadata_falls_back_for_stale_session(tmp_path: Path) -> None: + scope = workspace_scope_from_metadata( + { + WORKSPACE_SCOPE_METADATA_KEY: { + "project_path": str(tmp_path / "missing"), + "access_mode": "restricted", + } + }, + default_workspace=tmp_path, + default_restrict_to_workspace=False, + ) + + assert scope.project_path == tmp_path.resolve() + assert scope.access_mode == "full" + + +@pytest.mark.asyncio +async def test_filesystem_tool_uses_current_restricted_workspace_scope(tmp_path: Path) -> None: + project = tmp_path / "project" + project.mkdir() + outside = tmp_path / "outside.txt" + outside.write_text("nope") + inside = project / "inside.txt" + inside.write_text("ok") + tool = ReadFileTool(workspace=tmp_path, restrict_to_workspace=False) + scope = validate_workspace_scope_payload( + {"project_path": str(project), "access_mode": "restricted"}, + default_workspace=tmp_path, + default_restrict_to_workspace=False, + ) + token = bind_workspace_scope(scope) + try: + assert "ok" in await tool.execute(path="inside.txt") + assert "outside allowed directory" in await tool.execute(path=str(outside)) + finally: + reset_workspace_scope(token) + + +@pytest.mark.asyncio +async def test_exec_tool_uses_scope_project_as_default_cwd(tmp_path: Path) -> None: + project = tmp_path / "project" + project.mkdir() + tool = ExecTool(working_dir=str(tmp_path), restrict_to_workspace=False, timeout=5) + scope = validate_workspace_scope_payload( + {"project_path": str(project), "access_mode": "restricted"}, + default_workspace=tmp_path, + default_restrict_to_workspace=False, + ) + token = bind_workspace_scope(scope) + try: + result = await tool.execute(command="printf ok > scoped-marker.txt") + finally: + reset_workspace_scope(token) + + assert "Exit code: 0" in result + assert (project / "scoped-marker.txt").read_text() == "ok" + + +@pytest.mark.asyncio +async def test_exec_full_scope_allows_explicit_cwd_outside_project(tmp_path: Path) -> None: + project = tmp_path / "project" + outside = tmp_path / "outside" + project.mkdir() + outside.mkdir() + tool = ExecTool(working_dir=str(tmp_path), restrict_to_workspace=True, timeout=5) + scope = validate_workspace_scope_payload( + {"project_path": str(project), "access_mode": "full"}, + default_workspace=tmp_path, + default_restrict_to_workspace=True, + ) + token = bind_workspace_scope(scope) + try: + result = await tool.execute(command="printf ok > outside-marker.txt", working_dir=str(outside)) + finally: + reset_workspace_scope(token) + + assert "Exit code: 0" in result + assert (outside / "outside-marker.txt").read_text() == "ok" + + +def test_image_reference_scope_restricted_blocks_outside_and_full_allows(tmp_path: Path) -> None: + project = tmp_path / "project" + outside = tmp_path / "outside" + project.mkdir() + outside.mkdir() + ref = outside / "ref.png" + ref.write_bytes(PNG_BYTES) + tool = ImageGenerationTool( + workspace=tmp_path, + config=ImageGenerationToolConfig(enabled=True), + provider_config=ProviderConfig(api_key="sk-test"), + ) + + restricted = validate_workspace_scope_payload( + {"project_path": str(project), "access_mode": "restricted"}, + default_workspace=tmp_path, + default_restrict_to_workspace=False, + ) + token = bind_workspace_scope(restricted) + try: + with pytest.raises(ImageGenerationError, match="inside the workspace"): + tool._resolve_reference_image(str(ref)) + finally: + reset_workspace_scope(token) + + full = validate_workspace_scope_payload( + {"project_path": str(project), "access_mode": "full"}, + default_workspace=tmp_path, + default_restrict_to_workspace=True, + ) + token = bind_workspace_scope(full) + try: + assert tool._resolve_reference_image(str(ref)) == str(ref.resolve()) + finally: + reset_workspace_scope(token) + + +def test_message_media_scope_restricted_blocks_outside_and_full_allows(tmp_path: Path) -> None: + project = tmp_path / "project" + outside = tmp_path / "outside" + project.mkdir() + outside.mkdir() + media = outside / "shot.png" + media.write_bytes(PNG_BYTES) + tool = MessageTool(workspace=tmp_path, restrict_to_workspace=True) + + restricted = validate_workspace_scope_payload( + {"project_path": str(project), "access_mode": "restricted"}, + default_workspace=tmp_path, + default_restrict_to_workspace=False, + ) + token = bind_workspace_scope(restricted) + try: + with pytest.raises(PermissionError): + tool._resolve_media([str(media)]) + finally: + reset_workspace_scope(token) + + full = validate_workspace_scope_payload( + {"project_path": str(project), "access_mode": "full"}, + default_workspace=tmp_path, + default_restrict_to_workspace=True, + ) + token = bind_workspace_scope(full) + try: + assert tool._resolve_media([str(media)]) == [str(media)] + finally: + reset_workspace_scope(token) + + +@pytest.mark.asyncio +async def test_cli_app_scope_controls_working_dir( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project = tmp_path / "project" + outside = tmp_path / "outside" + data_dir = tmp_path / "data" + project.mkdir() + outside.mkdir() + registry = { + "meta": {}, + "clis": [ + { + "name": "demo", + "display_name": "Demo", + "version": "1.0", + "description": "demo", + "category": "test", + "install_cmd": "pip install demo", + "entry_point": "demo-cli", + } + ], + } + data_dir.mkdir() + (data_dir / "harness_registry_cache.json").write_text( + json.dumps({"_cached_at": time.time(), "data": registry}), + encoding="utf-8", + ) + (data_dir / "public_registry_cache.json").write_text( + json.dumps({"_cached_at": time.time(), "data": {"meta": {}, "clis": []}}), + encoding="utf-8", + ) + CliAppManager(workspace=project, data_dir=data_dir)._save_installed( + {"demo": {"entry_point": "demo-cli"}} + ) + monkeypatch.setattr("nanobot.apps.cli.service.get_runtime_subdir", lambda _name: data_dir) + monkeypatch.setattr( + "nanobot.apps.cli.service.shutil.which", + lambda entry: "/usr/bin/demo-cli" if entry == "demo-cli" else None, + ) + + seen: dict[str, str] = {} + + def fake_run(argv, **kwargs): + seen["cwd"] = kwargs["cwd"] + return SimpleNamespace(returncode=0, stdout="ok", stderr="") + + monkeypatch.setattr("nanobot.apps.cli.service.subprocess.run", fake_run) + tool = CliAppsTool( + workspace=tmp_path, + restrict_to_workspace=True, + runtime=CliAppsRuntimeConfig(run_timeout=5), + ) + + restricted = validate_workspace_scope_payload( + {"project_path": str(project), "access_mode": "restricted"}, + default_workspace=tmp_path, + default_restrict_to_workspace=False, + ) + token = bind_workspace_scope(restricted) + try: + blocked = await tool.execute(name="demo", working_dir=str(outside)) + finally: + reset_workspace_scope(token) + assert "outside the configured workspace" in blocked + + full = validate_workspace_scope_payload( + {"project_path": str(project), "access_mode": "full"}, + default_workspace=tmp_path, + default_restrict_to_workspace=True, + ) + token = bind_workspace_scope(full) + try: + result = await tool.execute(name="demo", working_dir=str(outside)) + finally: + reset_workspace_scope(token) + assert "CLI app 'demo' exited 0" in result + assert seen["cwd"] == str(outside.resolve()) + + +@pytest.mark.asyncio +async def test_spawn_tool_forwards_current_workspace_scope(tmp_path: Path) -> None: + project = tmp_path / "project" + project.mkdir() + scope = validate_workspace_scope_payload( + {"project_path": str(project), "access_mode": "restricted"}, + default_workspace=tmp_path, + default_restrict_to_workspace=False, + ) + + class Manager: + max_concurrent_subagents = 4 + + def __init__(self) -> None: + self.seen = None + + def get_running_count(self) -> int: + return 0 + + async def spawn(self, **kwargs): + self.seen = kwargs + return "spawned" + + manager = Manager() + tool = SpawnTool(manager) # type: ignore[arg-type] + token = bind_workspace_scope(scope) + try: + result = await tool.execute(task="inspect") + finally: + reset_workspace_scope(token) + + assert result == "spawned" + assert manager.seen["workspace_scope"] == scope diff --git a/tests/channels/test_websocket_channel.py b/tests/channels/test_websocket_channel.py index 2b0bb76cd..ba2b29411 100644 --- a/tests/channels/test_websocket_channel.py +++ b/tests/channels/test_websocket_channel.py @@ -30,6 +30,8 @@ from nanobot.channels.websocket import ( ) from nanobot.config.loader import load_config, save_config from nanobot.config.schema import Config, ModelPresetConfig +from nanobot.session import webui_turns as wth +from nanobot.session.manager import SessionManager from nanobot.webui.settings_api import settings_payload, update_provider_settings # -- Shared helpers (aligned with test_websocket_integration.py) --------------- @@ -57,6 +59,14 @@ def bus() -> MagicMock: return b +@pytest.fixture(autouse=True) +def isolate_webui_workspace_state(tmp_path, monkeypatch) -> None: + monkeypatch.setattr( + "nanobot.webui.workspaces.get_webui_dir", + lambda: tmp_path / "webui", + ) + + async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response: """Run GET in a thread to avoid blocking the asyncio loop shared with websockets.""" return await asyncio.to_thread( @@ -64,6 +74,15 @@ async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Re ) +async def _recv_ws_event(client: Any, event: str) -> dict[str, Any]: + """Receive until a specific websocket event appears.""" + for _ in range(10): + payload = json.loads(await client.recv()) + if payload.get("event") == event: + return payload + raise AssertionError(f"websocket event {event!r} was not received") + + def test_normalize_http_path_strips_trailing_slash_except_root() -> None: assert _normalize_http_path("/chat/") == "/chat" assert _normalize_http_path("/chat?x=1") == "/chat" @@ -81,6 +100,19 @@ def test_normalize_config_path_matches_request() -> None: assert _normalize_config_path("/") == "/" +def test_websocket_config_accepts_absolute_unix_socket(tmp_path) -> None: + socket_path = tmp_path / "engine.sock" + + cfg = WebSocketConfig(unix_socket_path=str(socket_path)) + + assert cfg.unix_socket_path == str(socket_path) + + +def test_websocket_config_rejects_relative_unix_socket() -> None: + with pytest.raises(ValueError, match="absolute path"): + WebSocketConfig(unix_socket_path="engine.sock") + + def test_parse_query_extracts_token_and_client_id() -> None: query = _parse_query("/?token=secret&client_id=u1") assert query.get("token") == ["secret"] @@ -204,6 +236,291 @@ async def test_plain_websocket_message_does_not_mark_webui(bus: MagicMock) -> No assert "webui" not in msg.metadata +@pytest.mark.asyncio +async def test_webui_message_scope_inherits_persisted_session_scope( + bus: MagicMock, + tmp_path, +) -> None: + default_workspace = tmp_path / "default" + project = tmp_path / "project" + default_workspace.mkdir() + project.mkdir() + sessions = SessionManager(tmp_path / "sessions") + channel = WebSocketChannel( + {"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"}, + bus, + session_manager=sessions, + workspace_path=default_workspace, + restrict_to_workspace=True, + ) + conn = AsyncMock() + conn.remote_address = ("127.0.0.1", 50123) + + await channel._dispatch_envelope( + conn, + "webui-client", + { + "type": "set_workspace_scope", + "chat_id": "chat-scope", + "workspace_scope": { + "project_path": str(project), + "access_mode": "full", + }, + }, + ) + await channel._dispatch_envelope( + conn, + "webui-client", + {"type": "message", "chat_id": "chat-scope", "content": "hello", "webui": True}, + ) + + msg = bus.publish_inbound.await_args.args[0] + assert msg.metadata["workspace_scope"] == { + "project_path": str(project.resolve()), + "access_mode": "full", + } + + +@pytest.mark.asyncio +async def test_webui_scope_expands_home_project_path( + bus: MagicMock, + tmp_path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + default_workspace = tmp_path / "default" + home = tmp_path / "home" + project = home / "Desktop" / "Photos" + default_workspace.mkdir() + project.mkdir(parents=True) + monkeypatch.setenv("HOME", str(home)) + monkeypatch.setenv("USERPROFILE", str(home)) + channel = WebSocketChannel( + {"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"}, + bus, + session_manager=SessionManager(tmp_path / "sessions"), + workspace_path=default_workspace, + restrict_to_workspace=True, + ) + conn = AsyncMock() + conn.remote_address = ("127.0.0.1", 50123) + + await channel._dispatch_envelope( + conn, + "webui-client", + { + "type": "set_workspace_scope", + "chat_id": "chat-scope", + "workspace_scope": { + "project_path": "~/Desktop/Photos", + "access_mode": "restricted", + }, + }, + ) + await channel._dispatch_envelope( + conn, + "webui-client", + {"type": "message", "chat_id": "chat-scope", "content": "hello", "webui": True}, + ) + + msg = bus.publish_inbound.await_args.args[0] + assert msg.metadata["workspace_scope"] == { + "project_path": str(project.resolve()), + "access_mode": "restricted", + } + + +@pytest.mark.asyncio +async def test_webui_scope_rejects_missing_project_path(bus: MagicMock, tmp_path) -> None: + default_workspace = tmp_path / "default" + default_workspace.mkdir() + channel = WebSocketChannel( + {"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"}, + bus, + session_manager=SessionManager(tmp_path / "sessions"), + workspace_path=default_workspace, + ) + conn = AsyncMock() + conn.remote_address = ("127.0.0.1", 50123) + + await channel._dispatch_envelope( + conn, + "webui-client", + { + "type": "set_workspace_scope", + "chat_id": "chat-scope", + "workspace_scope": { + "project_path": str(tmp_path / "missing"), + "access_mode": "restricted", + }, + }, + ) + + conn.send.assert_awaited() + payload = json.loads(conn.send.await_args.args[0]) + assert payload["event"] == "error" + assert payload["detail"] == "workspace_scope_rejected" + bus.publish_inbound.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_webui_scope_rejects_running_scope_change(bus: MagicMock, tmp_path) -> None: + default_workspace = tmp_path / "default" + project = tmp_path / "project" + other = tmp_path / "other" + default_workspace.mkdir() + project.mkdir() + other.mkdir() + sessions = SessionManager(tmp_path / "sessions") + channel = WebSocketChannel( + {"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"}, + bus, + session_manager=sessions, + workspace_path=default_workspace, + restrict_to_workspace=True, + ) + conn = AsyncMock() + conn.remote_address = ("127.0.0.1", 50123) + + await channel._dispatch_envelope( + conn, + "webui-client", + { + "type": "set_workspace_scope", + "chat_id": "chat-running", + "workspace_scope": { + "project_path": str(project), + "access_mode": "restricted", + }, + }, + ) + wth._WEBSOCKET_TURN_WALL_STARTED_AT["chat-running"] = 123.0 + try: + await channel._dispatch_envelope( + conn, + "webui-client", + { + "type": "message", + "chat_id": "chat-running", + "content": "hello", + "webui": True, + "workspace_scope": { + "project_path": str(other), + "access_mode": "full", + }, + }, + ) + finally: + wth._WEBSOCKET_TURN_WALL_STARTED_AT.clear() + + payload = json.loads(conn.send.await_args.args[0]) + assert payload["event"] == "error" + assert payload["detail"] == "workspace_scope_rejected" + assert payload["reason"] == "chat_running" + assert payload["chat_id"] == "chat-running" + bus.publish_inbound.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_webui_set_workspace_scope_rejects_running_chat(bus: MagicMock, tmp_path) -> None: + default_workspace = tmp_path / "default" + project = tmp_path / "project" + other = tmp_path / "other" + default_workspace.mkdir() + project.mkdir() + other.mkdir() + sessions = SessionManager(tmp_path / "sessions") + channel = WebSocketChannel( + {"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"}, + bus, + session_manager=sessions, + workspace_path=default_workspace, + restrict_to_workspace=True, + ) + conn = AsyncMock() + conn.remote_address = ("127.0.0.1", 50123) + + await channel._dispatch_envelope( + conn, + "webui-client", + { + "type": "set_workspace_scope", + "chat_id": "chat-running", + "workspace_scope": { + "project_path": str(project), + "access_mode": "restricted", + }, + }, + ) + conn.send.reset_mock() + + wth._WEBSOCKET_TURN_WALL_STARTED_AT["chat-running"] = 123.0 + try: + await channel._dispatch_envelope( + conn, + "webui-client", + { + "type": "set_workspace_scope", + "chat_id": "chat-running", + "workspace_scope": { + "project_path": str(other), + "access_mode": "full", + }, + }, + ) + finally: + wth._WEBSOCKET_TURN_WALL_STARTED_AT.clear() + + payload = json.loads(conn.send.await_args.args[0]) + assert payload["event"] == "error" + assert payload["detail"] == "workspace_scope_rejected" + assert payload["reason"] == "chat_running" + assert payload["chat_id"] == "chat-running" + + saved = sessions.read_session_file("websocket:chat-running") + assert saved["metadata"]["workspace_scope"] == { + "project_path": str(project.resolve()), + "access_mode": "restricted", + } + + +@pytest.mark.asyncio +async def test_webui_scope_rejects_non_loopback_custom_scope(bus: MagicMock, tmp_path) -> None: + default_workspace = tmp_path / "default" + project = tmp_path / "project" + default_workspace.mkdir() + project.mkdir() + sessions = SessionManager(tmp_path / "sessions") + channel = WebSocketChannel( + {"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"}, + bus, + session_manager=sessions, + workspace_path=default_workspace, + restrict_to_workspace=True, + ) + conn = AsyncMock() + conn.remote_address = ("203.0.113.8", 50123) + + await channel._dispatch_envelope( + conn, + "webui-client", + { + "type": "set_workspace_scope", + "chat_id": "chat-remote", + "workspace_scope": { + "project_path": str(project), + "access_mode": "full", + }, + }, + ) + + payload = json.loads(conn.send.await_args.args[0]) + assert payload["event"] == "error" + assert payload["detail"] == "workspace_scope_rejected" + assert payload["reason"] == "workspace controls are localhost-only" + assert payload["chat_id"] == "chat-remote" + assert sessions.read_session_file("websocket:chat-remote") is None + + @pytest.mark.asyncio async def test_send_delivers_json_message_with_media_and_reply() -> None: bus = MagicMock() @@ -1067,6 +1384,15 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist( config.tools.web.search.api_key = "brave-secret" save_config(config, config_path) monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path) + monkeypatch.setattr( + "nanobot.webui.settings_api._oauth_provider_status", + lambda _spec: { + "configured": False, + "account": None, + "expires_at": None, + "login_supported": True, + }, + ) channel = _ch(bus, port=port) channel._api_tokens["tok"] = time.monotonic() + 300 @@ -1103,6 +1429,8 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist( assert providers["atomic_chat"]["configured"] is False assert providers["atomic_chat"]["api_key_required"] is False assert providers["atomic_chat"]["default_api_base"] == "http://localhost:1337/v1" + assert providers["openai_codex"]["auth_type"] == "oauth" + assert providers["openai_codex"]["configured"] is False assert body["agent"]["has_api_key"] is True assert body["web_search"]["provider"] == "brave" assert body["web_search"]["api_key_hint"] == "brav••••cret" @@ -1121,18 +1449,29 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist( } assert image_providers["openrouter"]["label"] == "OpenRouter" assert image_providers["openrouter"]["configured"] is False - assert image_providers["openai_codex"]["configured"] is True + assert image_providers["openai_codex"]["auth_type"] == "oauth" + assert image_providers["openai_codex"]["configured"] is False assert image_providers["gemini"]["label"] == "Gemini" assert body["runtime"]["config_path"] == str(config_path) workspace_path = body["runtime"]["workspace_path"].replace("\\", "/") assert workspace_path.endswith("/.nanobot/workspace") assert body["runtime"]["gateway_port"] == 18790 assert body["advanced"]["exec_enabled"] is True + assert body["advanced"]["webui_allow_local_service_access"] is True + assert body["advanced"]["webui_default_access_mode"] == "default" + assert body["advanced"]["private_service_protection_enabled"] is True assert body["advanced"]["mcp_server_count"] == 0 assert body["restart_required_sections"] == [] assert "secret-key" not in settings.text assert "brave-secret" not in settings.text + unknown_api = await _http_get( + f"http://127.0.0.1:{port}/api/settings/model-configurations/missing", + headers={"Authorization": "Bearer tok"}, + ) + assert unknown_api.status_code == 404 + assert "" not in unknown_api.text.lower() + provider_updated = await _http_get( "http://127.0.0.1:" f"{port}/api/settings/provider/update?provider=openrouter" @@ -1204,6 +1543,21 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist( assert created_presets["fast-writing"]["label"] == "Fast writing" assert created_presets["fast-writing"]["provider"] == "openai" + updated_preset = await _http_get( + "http://127.0.0.1:" + f"{port}/api/settings/model-configurations/update" + "?name=fast-writing&label=Codex&provider=openai&model=openai%2Fgpt-5.5", + headers={"Authorization": "Bearer tok"}, + ) + assert updated_preset.status_code == 200 + updated_preset_body = updated_preset.json() + assert updated_preset_body["agent"]["model_preset"] == "fast-writing" + assert updated_preset_body["agent"]["model"] == "openai/gpt-5.5" + updated_presets = { + preset["name"]: preset for preset in updated_preset_body["model_presets"] + } + assert updated_presets["fast-writing"]["label"] == "Codex" + duplicate_preset = await _http_get( "http://127.0.0.1:" f"{port}/api/settings/model-configurations/create" @@ -1222,13 +1576,26 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist( assert search_updated.status_code == 200 search_body = search_updated.json() assert search_body["requires_restart"] is True - assert search_body["restart_required_sections"] == ["runtime", "web"] + assert search_body["restart_required_sections"] == ["browser", "runtime"] assert search_body["web_search"]["provider"] == "searxng" assert search_body["web_search"]["api_key_hint"] is None assert search_body["web_search"]["base_url"] == "https://search.example.com" assert search_body["web_search"]["max_results"] == 8 assert search_body["web"]["fetch"]["use_jina_reader"] is False + network_safety_updated = await _http_get( + "http://127.0.0.1:" + f"{port}/api/settings/network-safety/update?webui_allow_local_service_access=false&webui_default_access_mode=full", + headers={"Authorization": "Bearer tok"}, + ) + assert network_safety_updated.status_code == 200 + network_safety_body = network_safety_updated.json() + assert network_safety_body["requires_restart"] is True + assert network_safety_body["restart_required_sections"] == ["browser", "runtime"] + assert network_safety_body["advanced"]["webui_allow_local_service_access"] is False + assert network_safety_body["advanced"]["webui_default_access_mode"] == "full" + assert network_safety_body["advanced"]["private_service_protection_enabled"] is True + image_updated = await _http_get( "http://127.0.0.1:" f"{port}/api/settings/image-generation/update?enabled=true" @@ -1240,7 +1607,7 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist( assert image_updated.status_code == 200 image_body = image_updated.json() assert image_body["requires_restart"] is True - assert image_body["restart_required_sections"] == ["image", "runtime", "web"] + assert image_body["restart_required_sections"] == ["browser", "image", "runtime"] assert image_body["image_generation"]["enabled"] is True assert image_body["image_generation"]["model"] == "openai/gpt-image-1" assert image_body["image_generation"]["default_aspect_ratio"] == "16:9" @@ -1256,9 +1623,9 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist( assert image_provider_updated.status_code == 200 assert image_provider_updated.json()["requires_restart"] is True assert image_provider_updated.json()["restart_required_sections"] == [ + "browser", "image", "runtime", - "web", ] assert "sk-or-next" not in image_provider_updated.text @@ -1280,8 +1647,8 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist( assert saved.agents.defaults.model == "atomic_chat/test" assert saved.agents.defaults.provider == "atomic_chat" assert saved.agents.defaults.model_preset == "fast-writing" - assert saved.model_presets["fast-writing"].label == "Fast writing" - assert saved.model_presets["fast-writing"].model == "openai/gpt-4.1-mini" + assert saved.model_presets["fast-writing"].label == "Codex" + assert saved.model_presets["fast-writing"].model == "openai/gpt-5.5" assert saved.model_presets["fast-writing"].provider == "openai" assert saved.agents.defaults.timezone == "Asia/Shanghai" assert saved.agents.defaults.bot_name == "Nano" @@ -1296,6 +1663,7 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist( assert saved.tools.web.search.max_results == 8 assert saved.tools.web.search.timeout == 45 assert saved.tools.web.fetch.use_jina_reader is False + assert saved.tools.webui_allow_local_service_access is False assert saved.tools.image_generation.enabled is True assert saved.tools.image_generation.provider == "openrouter" assert saved.tools.image_generation.model == "openai/gpt-image-1" @@ -1335,6 +1703,43 @@ async def test_commands_api_returns_slash_command_metadata(bus: MagicMock) -> No await server_task +@pytest.mark.asyncio +async def test_bootstrap_exposes_native_surface(bus: MagicMock) -> None: + port = 29893 + channel = WebSocketChannel( + { + "enabled": True, + "allowFrom": ["*"], + "host": "127.0.0.1", + "port": port, + "path": "/ws", + "tokenIssueSecret": "native-secret", + "websocketRequiresToken": True, + }, + bus, + runtime_surface="native", + runtime_capabilities_overrides={"can_pick_folder": True}, + ) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + response = await _http_get( + f"http://127.0.0.1:{port}/webui/bootstrap", + headers={"X-Nanobot-Auth": "native-secret"}, + ) + assert response.status_code == 200 + body = response.json() + assert body["runtime_surface"] == "native" + assert body["runtime_capabilities"]["can_pick_folder"] is True + assert body["runtime_capabilities"]["can_restart_engine"] is True + assert body["token"].startswith("nbwt_") + finally: + await channel.stop() + await server_task + + def test_settings_payload_normalizes_camel_case_provider( bus: MagicMock, monkeypatch, @@ -1365,6 +1770,44 @@ def test_settings_payload_exposes_api_type_only_for_openai(monkeypatch, tmp_path assert "api_type" not in providers["custom"] +def test_settings_payload_reports_workspace_sandbox(monkeypatch, tmp_path) -> None: + config_path = tmp_path / "config.json" + config = Config() + config.tools.restrict_to_workspace = True + save_config(config, config_path) + monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path) + monkeypatch.setenv("NANOBOT_SANDBOX_ENFORCED", "macos_app_sandbox") + + body = settings_payload() + sandbox = body["advanced"]["workspace_sandbox"] + + assert sandbox["restrict_to_workspace"] is True + assert sandbox["level"] == "system" + assert sandbox["enforced"] is True + assert sandbox["provider"] == "macos_app_sandbox" + assert sandbox["provider_label"] == "macOS App Sandbox" + + +def test_settings_payload_includes_native_runtime_surface(monkeypatch, tmp_path) -> None: + config_path = tmp_path / "config.json" + save_config(Config(), config_path) + monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path) + + body = settings_payload( + surface="native", + runtime_capability_overrides={"can_open_logs": True}, + restart_required_sections=["runtime"], + ) + + assert body["surface"] == "native" + assert body["runtime_surface"] == "native" + assert body["runtime_capabilities"]["can_open_logs"] is True + assert body["runtime_capabilities"]["can_restart_engine"] is True + assert body["restart_behavior_by_section"]["runtime"] == "engineRestart" + assert body["requires_restart"] is True + assert body["apply_state"] == {"status": "pending", "sections": ["runtime"]} + + def test_update_provider_settings_ignores_api_type_for_non_openai(monkeypatch, tmp_path) -> None: config_path = tmp_path / "config.json" save_config(Config(), config_path) @@ -1671,6 +2114,8 @@ async def test_multiplex_new_chat_roundtrip(bus: MagicMock) -> None: OutboundMessage(channel="websocket", chat_id=new_chat, content="ok") ) reply = json.loads(await client.recv()) + if reply["event"] == "session_updated": + reply = json.loads(await client.recv()) assert reply["event"] == "message" assert reply["chat_id"] == new_chat assert reply["text"] == "ok" @@ -1691,16 +2136,16 @@ async def test_multiplex_two_chats_isolated(bus: MagicMock) -> None: await client.recv() # ready await client.send(json.dumps({"type": "new_chat"})) - chat_a = json.loads(await client.recv())["chat_id"] + chat_a = (await _recv_ws_event(client, "attached"))["chat_id"] await client.send(json.dumps({"type": "new_chat"})) - chat_b = json.loads(await client.recv())["chat_id"] + chat_b = (await _recv_ws_event(client, "attached"))["chat_id"] assert chat_a != chat_b # Push A → client sees A only (FIFO over the single WS). await channel.send( OutboundMessage(channel="websocket", chat_id=chat_a, content="for-A") ) - msg_a = json.loads(await client.recv()) + msg_a = await _recv_ws_event(client, "message") assert msg_a["chat_id"] == chat_a assert msg_a["text"] == "for-A" @@ -1708,7 +2153,7 @@ async def test_multiplex_two_chats_isolated(bus: MagicMock) -> None: await channel.send( OutboundMessage(channel="websocket", chat_id=chat_b, content="for-B") ) - msg_b = json.loads(await client.recv()) + msg_b = await _recv_ws_event(client, "message") assert msg_b["chat_id"] == chat_b assert msg_b["text"] == "for-B" finally: @@ -1830,6 +2275,9 @@ def test_sessions_list_includes_active_run_started_at() -> None: assert resp.status_code == 200 body = json.loads(resp.body.decode()) + workspace_scope = body["sessions"][0].pop("workspace_scope") + assert workspace_scope["project_path"] == str(channel._workspace_path) + assert workspace_scope["access_mode"] in {"restricted", "full"} assert body["sessions"] == [ { "key": "websocket:chat-1", diff --git a/tests/channels/test_websocket_http_routes.py b/tests/channels/test_websocket_http_routes.py index 18baf700e..ffb6d2c01 100644 --- a/tests/channels/test_websocket_http_routes.py +++ b/tests/channels/test_websocket_http_routes.py @@ -95,6 +95,7 @@ async def test_bootstrap_returns_token_for_localhost( body = resp.json() assert body["token"].startswith("nbwt_") assert body["ws_path"] == "/" + assert body["ws_url"] == "ws://127.0.0.1:29901/" assert body["expires_in"] > 0 assert isinstance(body.get("model_name"), str) finally: @@ -734,6 +735,17 @@ def test_bootstrap_accepts_static_token_as_secret(bus: MagicMock) -> None: assert body["token"].startswith("nbwt_") +def test_bootstrap_ws_url_uses_forwarded_https_host(bus: MagicMock) -> None: + channel = _ch(bus, host="127.0.0.1", port=29931) + resp = channel._handle_bootstrap( + _LOCAL, + _FakeReq({"Host": "nanobot.example", "X-Forwarded-Proto": "https"}), + ) + assert resp.status_code == 200 + body = json.loads(resp.body) + assert body["ws_url"] == "wss://nanobot.example/" + + def test_localhost_without_auth_is_valid(bus: MagicMock) -> None: channel = _ch(bus, host="127.0.0.1") resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS) diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index 940ce9865..5060a3a76 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -1521,6 +1521,35 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) assert "port 18792" in result.stdout +def test_configure_desktop_gateway_forces_local_websocket_only() -> None: + from nanobot.cli.commands import _configure_desktop_gateway + + config = Config() + config.channels.__pydantic_extra__ = { + "telegram": {"enabled": True, "token": "x"}, + "websocket": {"enabled": False, "port": 8765}, + } + + _configure_desktop_gateway( + config, + webui_port=29888, + webui_socket="/tmp/nanobot-test.sock", + token_issue_secret="secret", + ) + + extras = config.channels.__pydantic_extra__ or {} + assert config.gateway.host == "127.0.0.1" + assert config.gateway.port == 29888 + assert config.gateway.heartbeat.enabled is False + assert extras["telegram"]["enabled"] is False + assert extras["websocket"]["enabled"] is True + assert extras["websocket"]["host"] == "127.0.0.1" + assert extras["websocket"]["port"] == 29888 + assert extras["websocket"]["unix_socket_path"] == "/tmp/nanobot-test.sock" + assert extras["websocket"]["token_issue_secret"] == "secret" + assert extras["websocket"]["websocket_requires_token"] is True + + def test_gateway_health_endpoint_binds_and_serves_expected_responses( monkeypatch, tmp_path: Path ) -> None: diff --git a/tests/cli_apps/test_service.py b/tests/cli_apps/test_service.py index 0c42505f4..4b32f3fd3 100644 --- a/tests/cli_apps/test_service.py +++ b/tests/cli_apps/test_service.py @@ -143,6 +143,8 @@ def test_payload_merges_catalog_and_marks_unsupported_installs(tmp_path: Path) - assert apps["gimp"]["install_supported"] is True assert apps["gimp"]["source"] == "harness+public" assert apps["gimp"]["description"] == "Public duplicate entry" + assert apps["feishu"]["description"] == "Lark CLI" + assert apps["feishu"]["manifest"]["description"] == "Lark CLI" assert apps["clibrowser"]["install_supported"] is False assert apps["jimeng"]["install_supported"] is False assert apps["suno"]["install_supported"] is True diff --git a/tests/config/test_config_migration.py b/tests/config/test_config_migration.py index b27926ec0..1fd68b685 100644 --- a/tests/config/test_config_migration.py +++ b/tests/config/test_config_migration.py @@ -223,3 +223,24 @@ def test_load_config_resets_ssrf_whitelist_when_next_config_is_empty(tmp_path) - with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("ts.local", ["100.100.1.1"])): ok, _ = validate_url_target("http://ts.local/api") assert not ok + + +def test_load_config_defaults_local_service_access_to_enabled(tmp_path) -> None: + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps({"tools": {}}), encoding="utf-8") + + config = load_config(config_path) + + assert config.tools.webui_allow_local_service_access is True + + +def test_load_config_accepts_legacy_local_preview_access(tmp_path) -> None: + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps({"tools": {"allowLocalPreviewAccess": False}}), + encoding="utf-8", + ) + + config = load_config(config_path) + + assert config.tools.webui_allow_local_service_access is False diff --git a/tests/providers/test_openai_codex_provider.py b/tests/providers/test_openai_codex_provider.py index b3089e994..e1994555c 100644 --- a/tests/providers/test_openai_codex_provider.py +++ b/tests/providers/test_openai_codex_provider.py @@ -12,6 +12,7 @@ import nanobot.providers.base as provider_base from nanobot.providers.openai_codex_provider import ( OpenAICodexProvider, _codex_error_response, + _build_reasoning_options, _CodexHTTPError, _friendly_error, _request_codex, @@ -128,11 +129,12 @@ async def test_codex_prompt_cache_key_uses_stable_conversation_prefix(monkeypatc body, verify, on_content_delta=None, + on_thinking_delta=None, on_tool_call_delta=None, ): - _ = on_tool_call_delta + _ = on_thinking_delta, on_tool_call_delta bodies.append(body) - return "ok", [], "stop" + return "ok", [], "stop", None monkeypatch.setattr("nanobot.providers.openai_codex_provider._request_codex", fake_request) @@ -257,7 +259,7 @@ async def test_codex_retry_uses_structured_timeout_metadata(monkeypatch) -> None calls += 1 if calls == 1: raise httpx.ReadTimeout("") - return "ok", [], "stop" + return "ok", [], "stop", None async def fake_sleep(delay: float) -> None: delays.append(delay) @@ -397,3 +399,56 @@ def test_codex_429_classification_uses_raw_error_semantics( error_type, error_code = provider_base.LLMProvider._extract_error_type_code(raw) assert _should_retry_status(429, error_type, error_code, raw) is expected_retry + + +def test_codex_reasoning_options_request_summary_without_forcing_effort() -> None: + assert _build_reasoning_options(None) == {"summary": "auto"} + assert _build_reasoning_options("high") == {"summary": "auto", "effort": "high"} + assert _build_reasoning_options("none") == {"effort": "none"} + + +@pytest.mark.asyncio +async def test_codex_stream_surfaces_reasoning_summary(monkeypatch) -> None: + monkeypatch.setattr( + "nanobot.providers.openai_codex_provider.get_codex_token", + lambda: SimpleNamespace(account_id="acct", access="token"), + ) + + async def fake_request( + url, + headers, + body, + verify, + on_content_delta=None, + on_thinking_delta=None, + on_tool_call_delta=None, + ): + _ = url, headers, verify, on_tool_call_delta + assert body["reasoning"] == {"summary": "auto", "effort": "medium"} + if on_content_delta: + await on_content_delta("answer") + if on_thinking_delta: + await on_thinking_delta("summary") + return "answer", [], "stop", "summary" + + monkeypatch.setattr("nanobot.providers.openai_codex_provider._request_codex", fake_request) + + provider = OpenAICodexProvider() + content_deltas: list[str] = [] + thinking_deltas: list[str] = [] + + response = await provider.chat_stream( + [{"role": "user", "content": "hi"}], + reasoning_effort="medium", + on_content_delta=lambda delta: _append(content_deltas, delta), + on_thinking_delta=lambda delta: _append(thinking_deltas, delta), + ) + + assert content_deltas == ["answer"] + assert thinking_deltas == ["summary"] + assert response.content == "answer" + assert response.reasoning_content == "summary" + + +async def _append(target: list[str], value: str) -> None: + target.append(value) diff --git a/tests/providers/test_openai_responses.py b/tests/providers/test_openai_responses.py index 36040db58..49ae86493 100644 --- a/tests/providers/test_openai_responses.py +++ b/tests/providers/test_openai_responses.py @@ -1,10 +1,10 @@ """Tests for the shared openai_responses converters and parsers.""" +import json from unittest.mock import MagicMock, patch import pytest -from nanobot.providers.base import LLMResponse, ToolCallRequest from nanobot.providers.openai_responses.converters import ( convert_messages, convert_tools, @@ -13,6 +13,8 @@ from nanobot.providers.openai_responses.converters import ( ) from nanobot.providers.openai_responses.parsing import ( consume_sdk_stream, + consume_sse, + consume_sse_with_reasoning, map_finish_reason, parse_response_output, ) @@ -434,6 +436,166 @@ class TestParseResponseOutput: assert result.usage["total_tokens"] == 150 +# ====================================================================== +# parsing - consume_sse +# ====================================================================== + + +class _SseResponse: + def __init__(self, events: list[dict]): + self._events = events + + async def aiter_lines(self): + for event in self._events: + yield f"data: {json.dumps(event)}" + yield "" + + +class TestConsumeSse: + @pytest.mark.asyncio + async def test_legacy_consume_sse_returns_three_tuple(self): + response = _SseResponse([ + {"type": "response.output_text.delta", "delta": "hi"}, + {"type": "response.completed", "response": {"status": "completed"}}, + ]) + + content, tool_calls, finish_reason = await consume_sse(response) + + assert content == "hi" + assert tool_calls == [] + assert finish_reason == "stop" + + @pytest.mark.asyncio + async def test_reasoning_summary_delta_extracted(self): + response = _SseResponse([ + {"type": "response.reasoning_summary_text.delta", "delta": "thinking "}, + {"type": "response.reasoning_summary_text.delta", "delta": "briefly"}, + {"type": "response.output_text.delta", "delta": "answer"}, + {"type": "response.completed", "response": {"status": "completed"}}, + ]) + deltas: list[str] = [] + + async def on_reasoning(delta: str) -> None: + deltas.append(delta) + + content, tool_calls, finish_reason, reasoning = await consume_sse_with_reasoning( + response, + on_reasoning_delta=on_reasoning, + ) + + assert content == "answer" + assert tool_calls == [] + assert finish_reason == "stop" + assert reasoning == "thinking briefly" + assert deltas == ["thinking ", "briefly"] + + @pytest.mark.asyncio + async def test_reasoning_summary_from_completed_response(self): + response = _SseResponse([ + { + "type": "response.completed", + "response": { + "status": "completed", + "output": [ + {"type": "reasoning", "summary": [ + {"type": "summary_text", "text": "cached "}, + {"type": "summary_text", "text": "summary"}, + ]}, + ], + }, + }, + ]) + + _, _, _, reasoning = await consume_sse_with_reasoning(response) + + assert reasoning == "cached summary" + + @pytest.mark.asyncio + async def test_reasoning_summary_from_done_item(self): + response = _SseResponse([ + { + "type": "response.output_item.done", + "item": { + "type": "reasoning", + "summary": [{"type": "summary_text", "text": "done summary"}], + }, + }, + {"type": "response.completed", "response": {"status": "completed", "output": []}}, + ]) + deltas: list[str] = [] + + async def on_reasoning(delta: str) -> None: + deltas.append(delta) + + _, _, _, reasoning = await consume_sse_with_reasoning( + response, + on_reasoning_delta=on_reasoning, + ) + + assert reasoning == "done summary" + assert deltas == ["done summary"] + + @pytest.mark.asyncio + async def test_reasoning_summary_part_done_extracted(self): + response = _SseResponse([ + { + "type": "response.reasoning_summary_part.done", + "part": {"type": "summary_text", "text": "part summary"}, + }, + {"type": "response.completed", "response": {"status": "completed"}}, + ]) + + _, _, _, reasoning = await consume_sse_with_reasoning(response) + + assert reasoning == "part summary" + + @pytest.mark.asyncio + async def test_tool_call_done_arguments_callback(self): + response = _SseResponse([ + { + "type": "response.output_item.added", + "item": { + "type": "function_call", + "call_id": "c1", + "id": "fc1", + "name": "write_file", + "arguments": "", + }, + }, + { + "type": "response.function_call_arguments.done", + "call_id": "c1", + "arguments": '{"path":"a.txt","content":"hello\\n"}', + }, + { + "type": "response.output_item.done", + "item": { + "type": "function_call", + "call_id": "c1", + "id": "fc1", + "name": "write_file", + "arguments": '{"path":"a.txt","content":"hello\\n"}', + }, + }, + {"type": "response.completed", "response": {"status": "completed"}}, + ]) + deltas: list[dict] = [] + + async def cb(delta: dict) -> None: + deltas.append(delta) + + await consume_sse_with_reasoning(response, on_tool_call_delta=cb) + + assert deltas == [ + {"call_id": "c1", "name": "write_file", "arguments_delta": ""}, + { + "call_id": "c1", + "name": "write_file", + "arguments": '{"path":"a.txt","content":"hello\\n"}', + }, + ] + + # ====================================================================== # parsing - consume_sdk_stream # ====================================================================== @@ -544,6 +706,46 @@ class TestConsumeSdkStream: "arguments_delta": '{"path":"a.txt","content":"', }, {"call_id": "c1", "name": "write_file", "arguments_delta": "hello\\n"}, + { + "call_id": "c1", + "name": "write_file", + "arguments": '{"path":"a.txt","content":"hello\\n"}', + }, + ] + + @pytest.mark.asyncio + async def test_tool_call_done_item_arguments_callback_without_delta(self): + item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="") + item_added.name = "write_file" + ev1 = MagicMock(type="response.output_item.added", item=item_added) + item_done = MagicMock( + type="function_call", + call_id="c1", + id="fc1", + arguments='{"path":"late.txt","content":"done\\n"}', + ) + item_done.name = "write_file" + ev2 = MagicMock(type="response.output_item.done", item=item_done) + resp_obj = MagicMock(status="completed", usage=None, output=[]) + ev3 = MagicMock(type="response.completed", response=resp_obj) + deltas: list[dict] = [] + + async def cb(delta: dict) -> None: + deltas.append(delta) + + async def stream(): + for e in [ev1, ev2, ev3]: + yield e + + await consume_sdk_stream(stream(), on_tool_call_delta=cb) + + assert deltas == [ + {"call_id": "c1", "name": "write_file", "arguments_delta": ""}, + { + "call_id": "c1", + "name": "write_file", + "arguments": '{"path":"late.txt","content":"done\\n"}', + }, ] @pytest.mark.asyncio diff --git a/tests/security/test_security_network.py b/tests/security/test_security_network.py index a22c7e223..de4b90d3f 100644 --- a/tests/security/test_security_network.py +++ b/tests/security/test_security_network.py @@ -49,7 +49,7 @@ def test_rejects_missing_domain(): ]) def test_blocks_private_ipv4(ip: str, label: str): with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("evil.com", [ip])): - ok, err = validate_url_target(f"http://evil.com/path") + ok, err = validate_url_target("http://evil.com/path") assert not ok, f"Should block {label} ({ip})" assert "private" in err.lower() or "blocked" in err.lower() @@ -92,6 +92,21 @@ def test_detects_wget_localhost(): assert contains_internal_url("wget http://localhost:8080/secret") +def test_loopback_exception_allows_literal_localhost_only(): + with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("localhost", ["127.0.0.1"])): + assert not contains_internal_url("curl http://localhost:8765/", allow_loopback=True) + + +def test_loopback_exception_rejects_public_name_resolving_to_loopback(): + with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("example.com", ["127.0.0.1"])): + assert contains_internal_url("curl http://example.com:8765/", allow_loopback=True) + + +def test_loopback_exception_rejects_metadata(): + with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("169.254.169.254", ["169.254.169.254"])): + assert contains_internal_url("curl http://169.254.169.254/latest/meta-data/", allow_loopback=True) + + def test_allows_normal_curl(): with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("example.com", ["93.184.216.34"])): assert not contains_internal_url("curl https://example.com/api/data") diff --git a/tests/security/test_workspace_policy.py b/tests/security/test_workspace_policy.py new file mode 100644 index 000000000..0ed89dcc1 --- /dev/null +++ b/tests/security/test_workspace_policy.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from nanobot.security.workspace_policy import ( + WorkspaceBoundaryError, + is_path_within, + resolve_allowed_path, +) + + +def test_resolve_allowed_path_accepts_workspace_relative_path(tmp_path: Path) -> None: + workspace = tmp_path / "workspace" + workspace.mkdir() + target = workspace / "src" / "main.py" + target.parent.mkdir() + target.write_text("print('ok')", encoding="utf-8") + + resolved = resolve_allowed_path("src/main.py", workspace=workspace, allowed_root=workspace) + + assert resolved == target.resolve() + + +def test_resolve_allowed_path_blocks_parent_traversal(tmp_path: Path) -> None: + workspace = tmp_path / "workspace" + workspace.mkdir() + outside = tmp_path / "secret.txt" + outside.write_text("secret", encoding="utf-8") + + with pytest.raises(WorkspaceBoundaryError, match="outside allowed directory"): + resolve_allowed_path("../secret.txt", workspace=workspace, allowed_root=workspace) + + +def test_resolve_allowed_path_blocks_symlink_escape(tmp_path: Path) -> None: + workspace = tmp_path / "workspace" + workspace.mkdir() + outside = tmp_path / "outside" + outside.mkdir() + secret = outside / "secret.txt" + secret.write_text("secret", encoding="utf-8") + link = workspace / "linked-secret.txt" + try: + link.symlink_to(secret) + except OSError as exc: + pytest.skip(f"symlink creation is unavailable: {exc}") + + assert not is_path_within(link, workspace) + with pytest.raises(WorkspaceBoundaryError): + resolve_allowed_path("linked-secret.txt", workspace=workspace, allowed_root=workspace) + + +def test_resolve_allowed_path_allows_extra_root(tmp_path: Path) -> None: + workspace = tmp_path / "workspace" + workspace.mkdir() + media = tmp_path / "media" + media.mkdir() + image = media / "image.png" + image.write_bytes(b"\x89PNG\r\n\x1a\n") + + resolved = resolve_allowed_path( + image, + workspace=workspace, + allowed_root=workspace, + extra_allowed_roots=[media], + ) + + assert resolved == image.resolve() diff --git a/tests/security/test_workspace_sandbox.py b/tests/security/test_workspace_sandbox.py new file mode 100644 index 000000000..1ddd55c1b --- /dev/null +++ b/tests/security/test_workspace_sandbox.py @@ -0,0 +1,68 @@ +from pathlib import Path + +from nanobot.security.workspace_access import workspace_sandbox_status + + +def test_workspace_sandbox_disabled(tmp_path: Path) -> None: + status = workspace_sandbox_status( + restrict_to_workspace=False, + workspace=tmp_path, + environ={}, + ) + + assert status.level == "off" + assert status.enforced is False + assert status.provider == "none" + assert status.as_dict()["workspace_root"] == str(tmp_path.resolve()) + + +def test_workspace_sandbox_application_guard(tmp_path: Path) -> None: + status = workspace_sandbox_status( + restrict_to_workspace=True, + workspace=tmp_path, + environ={}, + ) + + assert status.level == "application" + assert status.enforced is False + assert status.provider == "none" + assert "application-level" in status.summary + + +def test_workspace_sandbox_system_provider_from_compact_env(tmp_path: Path) -> None: + status = workspace_sandbox_status( + restrict_to_workspace=True, + workspace=tmp_path, + environ={"NANOBOT_SANDBOX_ENFORCED": "macos_app_sandbox"}, + ) + + assert status.level == "system" + assert status.enforced is True + assert status.provider == "macos_app_sandbox" + assert status.provider_label == "macOS App Sandbox" + + +def test_workspace_sandbox_system_provider_from_boolean_env(tmp_path: Path) -> None: + status = workspace_sandbox_status( + restrict_to_workspace=True, + workspace=tmp_path, + environ={ + "NANOBOT_WORKSPACE_SANDBOX_ENFORCED": "true", + "NANOBOT_WORKSPACE_SANDBOX_PROVIDER": "macOS App Sandbox", + }, + ) + + assert status.level == "system" + assert status.enforced is True + assert status.provider == "macos_app_sandbox" + + +def test_workspace_sandbox_false_env_does_not_enforce(tmp_path: Path) -> None: + status = workspace_sandbox_status( + restrict_to_workspace=True, + workspace=tmp_path, + environ={"NANOBOT_WORKSPACE_SANDBOX_ENFORCED": "false"}, + ) + + assert status.level == "application" + assert status.enforced is False diff --git a/tests/test_tool_contextvars.py b/tests/test_tool_contextvars.py index c5f02326d..e2b7f66ab 100644 --- a/tests/test_tool_contextvars.py +++ b/tests/test_tool_contextvars.py @@ -65,6 +65,7 @@ async def test_spawn_tool_keeps_task_local_context() -> None: session_key: str, origin_message_id: str | None = None, temperature: float | None = None, + workspace_scope=None, ) -> str: seen.append((origin_channel, origin_chat_id, session_key)) return f"{origin_channel}:{origin_chat_id}:{task}" @@ -178,6 +179,7 @@ async def test_spawn_tool_basic_set_context_and_execute() -> None: session_key, origin_message_id=None, temperature=None, + workspace_scope=None, ): seen.append((origin_channel, origin_chat_id, session_key)) return f"ok: {task}" @@ -211,6 +213,7 @@ async def test_spawn_tool_default_values_without_set_context() -> None: session_key, origin_message_id=None, temperature=None, + workspace_scope=None, ): seen.append((origin_channel, origin_chat_id, session_key)) return "ok" diff --git a/tests/tools/test_apply_patch_tool.py b/tests/tools/test_apply_patch_tool.py index 2ba247368..9ddc35a85 100644 --- a/tests/tools/test_apply_patch_tool.py +++ b/tests/tools/test_apply_patch_tool.py @@ -89,7 +89,7 @@ def test_apply_patch_edits_add_to_existing_file(tmp_path): ) -def test_apply_patch_edits_delete(tmp_path): +def test_apply_patch_rejects_delete_action(tmp_path): target = tmp_path / "utils.py" target.write_text("def unused():\n pass\ndef used():\n return 1\n") tool = ApplyPatchTool(workspace=tmp_path) @@ -106,51 +106,8 @@ def test_apply_patch_edits_delete(tmp_path): ) ) - assert "update utils.py" in result - assert target.read_text() == "def used():\n return 1\n" - - -def test_apply_patch_edits_delete_entire_file(tmp_path): - target = tmp_path / "obsolete.txt" - target.write_text("remove me\n") - tool = ApplyPatchTool(workspace=tmp_path) - - result = asyncio.run( - tool.execute( - edits=[ - { - "path": "obsolete.txt", - "action": "delete", - "old_text": "remove me\n", - } - ] - ) - ) - - assert "delete obsolete.txt" in result - assert not target.exists() - - -def test_apply_patch_edits_delete_substring_with_surrounding_whitespace(tmp_path): - target = tmp_path / "keep_whitespace.txt" - target.write_text(" token \n") - tool = ApplyPatchTool(workspace=tmp_path) - - result = asyncio.run( - tool.execute( - edits=[ - { - "path": "keep_whitespace.txt", - "action": "delete", - "old_text": "token", - } - ] - ) - ) - - assert "update keep_whitespace.txt" in result - assert target.exists() - assert target.read_text() == " \n" + assert "unknown action: delete" in result + assert target.read_text() == "def unused():\n pass\ndef used():\n return 1\n" def test_apply_patch_edits_batch_multiple_files(tmp_path): @@ -319,8 +276,9 @@ def test_apply_patch_edits_rolls_back_when_late_operation_fails(tmp_path): }, { "path": "missing.txt", - "action": "delete", + "action": "replace", "old_text": "remove me", + "new_text": "removed", }, ] ) diff --git a/tests/tools/test_exec_security.py b/tests/tools/test_exec_security.py index fb6731f03..7540f87b8 100644 --- a/tests/tools/test_exec_security.py +++ b/tests/tools/test_exec_security.py @@ -9,6 +9,7 @@ from unittest.mock import patch import pytest from nanobot.agent.tools.shell import ExecTool +from nanobot.security.workspace_access import bind_workspace_scope, build_workspace_scope, reset_workspace_scope def _fake_resolve_private(hostname, port, family=0, type_=0): @@ -42,6 +43,70 @@ async def test_exec_blocks_wget_localhost(): assert "Error" in result +def test_exec_full_workspace_scope_allows_loopback(tmp_path): + tool = ExecTool(working_dir=str(tmp_path)) + scope = build_workspace_scope(tmp_path, "full", source_channel="websocket") + token = bind_workspace_scope(scope) + try: + with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_localhost): + error = tool._guard_command("curl http://localhost:8765/", str(tmp_path)) + finally: + reset_workspace_scope(token) + assert error is None + + +def test_exec_core_full_workspace_scope_blocks_loopback(tmp_path): + tool = ExecTool(working_dir=str(tmp_path)) + scope = build_workspace_scope(tmp_path, "full") + token = bind_workspace_scope(scope) + try: + with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_localhost): + error = tool._guard_command("curl http://localhost:8765/", str(tmp_path)) + finally: + reset_workspace_scope(token) + assert error is not None + assert "internal/private" in error + + +def test_exec_full_workspace_scope_blocks_loopback_when_local_service_disabled(tmp_path): + tool = ExecTool(working_dir=str(tmp_path), webui_allow_local_service_access=False) + scope = build_workspace_scope(tmp_path, "full", source_channel="websocket") + token = bind_workspace_scope(scope) + try: + with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_localhost): + error = tool._guard_command("curl http://localhost:8765/", str(tmp_path)) + finally: + reset_workspace_scope(token) + assert error is not None + assert "internal/private" in error + + +def test_exec_restricted_workspace_scope_blocks_loopback(tmp_path): + tool = ExecTool(working_dir=str(tmp_path)) + scope = build_workspace_scope(tmp_path, "restricted", source_channel="websocket") + token = bind_workspace_scope(scope) + try: + with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_localhost): + error = tool._guard_command("curl http://localhost:8765/", str(tmp_path)) + finally: + reset_workspace_scope(token) + assert error is not None + assert "internal/private" in error + + +def test_exec_full_workspace_scope_still_blocks_metadata(tmp_path): + tool = ExecTool(working_dir=str(tmp_path)) + scope = build_workspace_scope(tmp_path, "full", source_channel="websocket") + token = bind_workspace_scope(scope) + try: + with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private): + error = tool._guard_command("curl http://169.254.169.254/latest/meta-data/", str(tmp_path)) + finally: + reset_workspace_scope(token) + assert error is not None + assert "internal/private" in error + + @pytest.mark.asyncio async def test_exec_allows_normal_commands(): tool = ExecTool(timeout=5) diff --git a/tests/tools/test_tool_loader.py b/tests/tools/test_tool_loader.py index bfe35d910..8499b14fe 100644 --- a/tests/tools/test_tool_loader.py +++ b/tests/tools/test_tool_loader.py @@ -5,8 +5,6 @@ from dataclasses import fields from typing import Any from unittest.mock import MagicMock -import pytest - from nanobot.agent.tools.base import Tool @@ -115,6 +113,31 @@ def test_discover_skips_private_classes(): assert not cls.__name__.startswith("_") +def test_loader_registers_exec_with_real_tools_config(tmp_path): + """Real config objects catch bad ctx.config attribute paths that mocks hide.""" + from types import SimpleNamespace + + from nanobot.agent.tools.registry import ToolRegistry + from nanobot.config.schema import ToolsConfig + + ctx = ToolContext( + config=ToolsConfig(), + workspace=str(tmp_path), + bus=None, + subagent_manager=SimpleNamespace( + get_running_count=lambda: 0, + max_concurrent_subagents=4, + ), + cron_service=None, + timezone="UTC", + ) + registry = ToolRegistry() + registered = ToolLoader().load(ctx, registry) + + assert "exec" in registered + assert registry.has("exec") + + # --- Task 4: _FsTool.create() --- from pathlib import Path diff --git a/tests/tools/test_web_fetch_security.py b/tests/tools/test_web_fetch_security.py index 448229edc..89ff9d9f9 100644 --- a/tests/tools/test_web_fetch_security.py +++ b/tests/tools/test_web_fetch_security.py @@ -12,6 +12,7 @@ import pytest from nanobot.agent.tools import web as web_module from nanobot.agent.tools.web import WebFetchTool from nanobot.config.schema import WebFetchConfig +from nanobot.security.workspace_access import bind_workspace_scope, build_workspace_scope, reset_workspace_scope _REAL_GETADDRINFO = socket.getaddrinfo @@ -45,6 +46,24 @@ async def test_web_fetch_blocks_localhost(): assert "error" in data +@pytest.mark.asyncio +async def test_web_fetch_blocks_localhost_even_in_full_workspace_scope(tmp_path): + tool = WebFetchTool() + scope = build_workspace_scope(tmp_path, "full") + + def _resolve_localhost(hostname, port, family=0, type_=0): + return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0))] + + token = bind_workspace_scope(scope) + try: + with patch("nanobot.security.network.socket.getaddrinfo", _resolve_localhost): + result = await tool.execute(url="http://localhost/admin") + finally: + reset_workspace_scope(token) + data = json.loads(result) + assert "error" in data + + @pytest.mark.asyncio async def test_web_fetch_result_contains_untrusted_flag(): """When fetch succeeds, result JSON must include untrusted=True and the banner.""" diff --git a/tests/utils/test_file_edit_events.py b/tests/utils/test_file_edit_events.py index fe035b41b..93240cf95 100644 --- a/tests/utils/test_file_edit_events.py +++ b/tests/utils/test_file_edit_events.py @@ -86,13 +86,10 @@ def test_apply_patch_prepares_trackers_for_each_touched_file(tmp_path: Path) -> (tmp_path / "src").mkdir() existing = tmp_path / "src" / "existing.py" existing.write_text("old\nkeep\n", encoding="utf-8") - delete_me = tmp_path / "src" / "delete_me.py" - delete_me.write_text("gone\n", encoding="utf-8") edits = [ {"path": "src/new.py", "action": "add", "new_text": "fresh"}, {"path": "src/existing.py", "action": "replace", "old_text": "old", "new_text": "new"}, - {"path": "src/delete_me.py", "action": "delete", "old_text": "gone\n"}, ] trackers = prepare_file_edit_trackers( @@ -106,18 +103,15 @@ def test_apply_patch_prepares_trackers_for_each_touched_file(tmp_path: Path) -> assert [tracker.display_path for tracker in trackers] == [ "src/new.py", "src/existing.py", - "src/delete_me.py", ] (tmp_path / "src" / "new.py").write_text("fresh\n", encoding="utf-8") existing.write_text("new\nkeep\n", encoding="utf-8") - delete_me.unlink() events = [build_file_edit_end_event(tracker, {"edits": edits}) for tracker in trackers] by_path = {event["path"]: event for event in events} assert (by_path["src/new.py"]["added"], by_path["src/new.py"]["deleted"]) == (1, 0) assert (by_path["src/existing.py"]["added"], by_path["src/existing.py"]["deleted"]) == (1, 1) - assert (by_path["src/delete_me.py"]["added"], by_path["src/delete_me.py"]["deleted"]) == (0, 1) def test_apply_patch_dry_run_does_not_prepare_file_edit_trackers(tmp_path: Path) -> None: diff --git a/tests/utils/test_webui_sidebar_state.py b/tests/utils/test_webui_sidebar_state.py index 0244a304a..6294a0d5d 100644 --- a/tests/utils/test_webui_sidebar_state.py +++ b/tests/utils/test_webui_sidebar_state.py @@ -27,6 +27,7 @@ def test_sidebar_state_normalizes_old_or_partial_payload(tmp_path, monkeypatch) "pinned_keys": ["websocket:a", "websocket:a", "", 123], "archived_keys": ["websocket:b"], "title_overrides": {"websocket:a": " Release notes ", "bad": ""}, + "project_name_overrides": {"/repo": " Core ", "bad": ""}, "tags_by_key": {"websocket:a": ["work", "work", ""]}, "collapsed_groups": {"Earlier": 1}, "view": {"density": "tiny", "show_archived": True, "sort": "nope"}, @@ -41,6 +42,7 @@ def test_sidebar_state_normalizes_old_or_partial_payload(tmp_path, monkeypatch) assert state["pinned_keys"] == ["websocket:a"] assert state["archived_keys"] == ["websocket:b"] assert state["title_overrides"] == {"websocket:a": "Release notes"} + assert state["project_name_overrides"] == {"/repo": "Core"} assert state["tags_by_key"] == {"websocket:a": ["work"]} assert state["collapsed_groups"] == {"Earlier": True} assert state["view"] == { @@ -60,6 +62,7 @@ def test_sidebar_state_write_is_scoped_to_config_data_dir(tmp_path, monkeypatch) "pinned_keys": ["websocket:a"], "archived_keys": ["websocket:b"], "title_overrides": {"websocket:a": "Release"}, + "project_name_overrides": {"/repo": "Core"}, "view": {"density": "compact", "show_previews": True}, } ) @@ -67,6 +70,7 @@ def test_sidebar_state_write_is_scoped_to_config_data_dir(tmp_path, monkeypatch) assert state["pinned_keys"] == ["websocket:a"] assert state["archived_keys"] == ["websocket:b"] assert state["title_overrides"] == {"websocket:a": "Release"} + assert state["project_name_overrides"] == {"/repo": "Core"} assert state["view"]["density"] == "compact" assert state["view"]["show_previews"] is True assert webui_sidebar_state_path().is_file() diff --git a/tests/utils/test_webui_transcript.py b/tests/utils/test_webui_transcript.py index 6400900a8..f676c0486 100644 --- a/tests/utils/test_webui_transcript.py +++ b/tests/utils/test_webui_transcript.py @@ -122,6 +122,103 @@ def test_replay_file_edit_event_creates_file_activity(tmp_path, monkeypatch) -> assert msgs[2]["activitySegmentId"] != msgs[1]["activitySegmentId"] +def test_replay_file_edit_absorbs_matching_write_tool_event() -> None: + msgs = replay_transcript_to_ui_messages([ + { + "event": "message", + "chat_id": "t-file", + "text": 'write_file({"path":"foo.txt"})', + "kind": "tool_hint", + "tool_events": [ + { + "phase": "start", + "call_id": "call-write", + "name": "write_file", + "arguments": {"path": "foo.txt", "content": "hello\n"}, + }, + ], + }, + { + "event": "file_edit", + "chat_id": "t-file", + "edits": [ + { + "version": 1, + "call_id": "call-write", + "tool": "write_file", + "path": "foo.txt", + "phase": "start", + "added": 1, + "deleted": 0, + "approximate": True, + "status": "editing", + }, + ], + }, + { + "event": "message", + "chat_id": "t-file", + "text": "", + "kind": "progress", + "tool_events": [ + { + "phase": "end", + "call_id": "call-write", + "name": "write_file", + "arguments": {"path": "foo.txt", "content": "hello\n"}, + "result": "ok", + }, + ], + }, + ]) + + assert len(msgs) == 1 + assert msgs[0]["kind"] == "trace" + assert msgs[0]["traces"] == [] + assert "toolEvents" not in msgs[0] + assert msgs[0]["fileEdits"] == [ + { + "version": 1, + "call_id": "call-write", + "tool": "write_file", + "path": "foo.txt", + "phase": "start", + "added": 1, + "deleted": 0, + "approximate": True, + "status": "editing", + }, + ] + + +def test_replay_keeps_interrupted_pre_tool_text_in_activity() -> None: + msgs = replay_transcript_to_ui_messages([ + {"event": "delta", "chat_id": "t-stream", "text": "I will inspect first."}, + {"event": "stream_end", "chat_id": "t-stream"}, + { + "event": "message", + "chat_id": "t-stream", + "text": 'exec({"cmd":"ls"})', + "kind": "tool_hint", + }, + { + "event": "stream_end", + "chat_id": "t-stream", + "text": "Done. Open index.html to play.", + }, + ]) + + assert len(msgs) == 3 + assert msgs[0]["role"] == "assistant" + assert msgs[0]["content"] == "" + assert msgs[0]["reasoning"] == "I will inspect first." + assert "isStreaming" not in msgs[0] + assert msgs[1]["kind"] == "trace" + assert msgs[1]["traces"] == ['exec({"cmd":"ls"})'] + assert msgs[2]["role"] == "assistant" + assert msgs[2]["content"] == "Done. Open index.html to play." + + def test_replay_tool_events_dedupes_finish_after_start() -> None: msgs = replay_transcript_to_ui_messages([ { diff --git a/tests/utils/test_webui_workspaces.py b/tests/utils/test_webui_workspaces.py new file mode 100644 index 000000000..cf7941b6c --- /dev/null +++ b/tests/utils/test_webui_workspaces.py @@ -0,0 +1,154 @@ +import json + +from nanobot.security.workspace_access import default_workspace_scope +from nanobot.session.manager import SessionManager +from nanobot.webui.workspaces import ( + WebUIWorkspaceController, + read_webui_default_access_mode, + read_webui_workspace_state, + webui_workspace_state_path, + write_webui_default_access_mode, + workspaces_payload, +) + + +def test_workspace_state_defaults_when_file_missing(tmp_path, monkeypatch) -> None: + monkeypatch.setattr("nanobot.webui.workspaces.get_webui_dir", lambda: tmp_path / "webui") + + state = read_webui_workspace_state() + + assert state["default_access_mode"] == "default" + assert webui_workspace_state_path() == tmp_path / "webui" / "workspace-state.json" + + +def test_workspace_state_ignores_legacy_project_history(tmp_path, monkeypatch) -> None: + monkeypatch.setattr("nanobot.webui.workspaces.get_webui_dir", lambda: tmp_path / "webui") + project = tmp_path / "project" + project.mkdir() + path = webui_workspace_state_path() + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps( + { + "recent_projects": [ + {"project_path": str(project)}, + {"project_path": str(tmp_path / "missing")}, + ], + "last_scope": { + "project_path": str(project), + "access_mode": "full", + }, + } + ), + encoding="utf-8", + ) + + state = read_webui_workspace_state() + + assert "recent_projects" not in state + assert "last_scope" not in state + assert state["default_access_mode"] == "default" + + +def test_workspace_payload_is_config_data_dir_scoped(tmp_path, monkeypatch) -> None: + monkeypatch.setattr("nanobot.webui.workspaces.get_webui_dir", lambda: tmp_path / "webui") + default = tmp_path / "default" + default.mkdir() + + payload = workspaces_payload( + default_workspace=default, + default_restrict_to_workspace=False, + controls_available=True, + ) + + assert payload["default_scope"]["project_path"] == str(default.resolve()) + assert payload["default_scope"]["access_mode"] == "full" + assert payload["default_access_mode"] == "default" + assert payload["controls"]["can_change_project"] is True + + +def test_workspace_payload_hides_mutable_state_when_controls_unavailable( + tmp_path, + monkeypatch, +) -> None: + monkeypatch.setattr("nanobot.webui.workspaces.get_webui_dir", lambda: tmp_path / "webui") + default = tmp_path / "default" + default.mkdir() + + payload = workspaces_payload( + default_workspace=default, + default_restrict_to_workspace=False, + controls_available=False, + ) + + assert payload["default_scope"]["project_path"] == str(default.resolve()) + assert payload["controls"]["can_change_project"] is False + assert payload["controls"]["can_use_full_access"] is False + + +def test_workspace_payload_uses_webui_default_access_mode(tmp_path, monkeypatch) -> None: + monkeypatch.setattr("nanobot.webui.workspaces.get_webui_dir", lambda: tmp_path / "webui") + default = tmp_path / "default" + default.mkdir() + + assert write_webui_default_access_mode("full") is True + assert write_webui_default_access_mode("full") is False + + payload = workspaces_payload( + default_workspace=default, + default_restrict_to_workspace=True, + controls_available=True, + ) + + assert payload["default_access_mode"] == "full" + assert payload["default_scope"]["project_path"] == str(default.resolve()) + assert payload["default_scope"]["access_mode"] == "full" + + +def test_legacy_restricted_webui_default_access_mode_maps_to_default(tmp_path, monkeypatch) -> None: + monkeypatch.setattr("nanobot.webui.workspaces.get_webui_dir", lambda: tmp_path / "webui") + + assert write_webui_default_access_mode("restricted") is False + assert read_webui_default_access_mode() == "default" + + +def test_webui_default_access_applies_to_unscoped_old_sessions(tmp_path, monkeypatch) -> None: + monkeypatch.setattr("nanobot.webui.workspaces.get_webui_dir", lambda: tmp_path / "webui") + default = tmp_path / "default" + default.mkdir() + sessions = SessionManager(tmp_path / "sessions") + sessions.save(sessions.get_or_create("websocket:old-chat")) + write_webui_default_access_mode("full") + controller = WebUIWorkspaceController( + session_manager=sessions, + default_workspace=default, + default_restrict_to_workspace=True, + ) + + scope = controller.scope_for_session_key("websocket:old-chat") + new_scope = controller.scope_for_new_chat({}, controls_available=True) + + assert scope.project_path == default.resolve() + assert scope.access_mode == "full" + assert new_scope.access_mode == "full" + + +def test_webui_default_access_does_not_override_explicit_session_scope(tmp_path, monkeypatch) -> None: + monkeypatch.setattr("nanobot.webui.workspaces.get_webui_dir", lambda: tmp_path / "webui") + default = tmp_path / "default" + project = tmp_path / "project" + default.mkdir() + project.mkdir() + sessions = SessionManager(tmp_path / "sessions") + controller = WebUIWorkspaceController( + session_manager=sessions, + default_workspace=default, + default_restrict_to_workspace=True, + ) + explicit = default_workspace_scope(project, restrict_to_workspace=False) + controller.persist_scope("explicit-chat", explicit) + + scope = controller.scope_for_session_key("websocket:explicit-chat") + + assert scope.project_path == project.resolve() + assert scope.access_mode == "full" diff --git a/tests/webui/test_settings_api.py b/tests/webui/test_settings_api.py index d1b6c175e..ce8f74789 100644 --- a/tests/webui/test_settings_api.py +++ b/tests/webui/test_settings_api.py @@ -1,10 +1,20 @@ from __future__ import annotations +import json + import pytest from nanobot.config.loader import load_config, save_config -from nanobot.config.schema import Config -from nanobot.webui.settings_api import WebUISettingsError, create_model_configuration +from nanobot.config.schema import Config, ModelPresetConfig +from nanobot.webui.settings_api import ( + WebUISettingsError, + _oauth_provider_status, + create_model_configuration, + settings_payload, + update_model_configuration, + update_network_safety_settings, +) +from nanobot.providers.registry import find_by_name def test_create_model_configuration_writes_label_and_selects( @@ -65,3 +75,237 @@ def test_create_model_configuration_rejects_unconfigured_provider( "model": ["openai/gpt-4.1"], } ) + + +def test_update_model_configuration_edits_named_preset_and_selects( + tmp_path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config_path = tmp_path / "config.json" + config = Config() + config.providers.openai.api_key = "sk-test" + config.model_presets["codex"] = ModelPresetConfig( + label="Old Codex", + provider="openai", + model="openai/gpt-4.1", + ) + save_config(config, config_path) + monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path) + monkeypatch.setattr( + "nanobot.webui.settings_api._oauth_provider_status", + lambda spec: { + "configured": spec.name == "openai_codex", + "account": "acct-test", + "expires_at": 123, + "login_supported": True, + }, + ) + + payload = update_model_configuration( + { + "name": ["codex"], + "label": ["Codex"], + "provider": ["openai_codex"], + "model": ["openai-codex/gpt-5.5"], + } + ) + + assert payload["agent"]["model_preset"] == "codex" + assert payload["agent"]["model"] == "openai-codex/gpt-5.5" + saved = load_config(config_path) + assert saved.agents.defaults.model_preset == "codex" + assert saved.model_presets["codex"].label == "Codex" + assert saved.model_presets["codex"].provider == "openai_codex" + assert saved.model_presets["codex"].model == "openai-codex/gpt-5.5" + + +def test_update_model_configuration_rejects_default_preset( + tmp_path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config_path = tmp_path / "config.json" + save_config(Config(), config_path) + monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path) + + with pytest.raises(WebUISettingsError, match="model configuration is required"): + update_model_configuration({"name": ["default"], "model": ["openai/gpt-4.1"]}) + + +def test_settings_payload_includes_oauth_provider_status( + tmp_path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config_path = tmp_path / "config.json" + save_config(Config(), config_path) + monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path) + + def fake_oauth_status(spec): + if spec.name == "openai_codex": + return { + "configured": True, + "account": "acct-test", + "expires_at": 123, + "login_supported": True, + } + return { + "configured": False, + "account": None, + "expires_at": None, + "login_supported": True, + } + + monkeypatch.setattr("nanobot.webui.settings_api._oauth_provider_status", fake_oauth_status) + + payload = settings_payload() + providers = {row["name"]: row for row in payload["providers"]} + + assert providers["openai_codex"]["auth_type"] == "oauth" + assert providers["openai_codex"]["configured"] is True + assert providers["openai_codex"]["oauth_account"] == "acct-test" + + +def test_settings_payload_includes_network_safety_fields( + tmp_path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config_path = tmp_path / "config.json" + config = Config() + config.tools.webui_allow_local_service_access = False + config.tools.ssrf_whitelist = ["100.64.0.0/10"] + save_config(config, config_path) + monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path) + monkeypatch.setattr("nanobot.webui.workspaces.get_webui_dir", lambda: tmp_path / "webui") + + payload = settings_payload() + + assert payload["advanced"]["webui_allow_local_service_access"] is False + assert payload["advanced"]["allow_local_preview_access"] is False + assert payload["advanced"]["webui_default_access_mode"] == "default" + assert payload["advanced"]["private_service_protection_enabled"] is True + assert payload["advanced"]["ssrf_whitelist_count"] == 1 + + +def test_update_network_safety_settings_writes_local_service_flag( + tmp_path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config_path = tmp_path / "config.json" + save_config(Config(), config_path) + monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path) + monkeypatch.setattr("nanobot.webui.workspaces.get_webui_dir", lambda: tmp_path / "webui") + + payload = update_network_safety_settings( + { + "webui_allow_local_service_access": ["false"], + "webui_default_access_mode": ["full"], + } + ) + + saved = load_config(config_path) + saved_raw = json.loads(config_path.read_text(encoding="utf-8")) + assert saved.tools.webui_allow_local_service_access is False + assert saved_raw["tools"]["webuiAllowLocalServiceAccess"] is False + assert "allowLocalPreviewAccess" not in saved_raw["tools"] + assert payload["advanced"]["webui_allow_local_service_access"] is False + assert payload["advanced"]["webui_default_access_mode"] == "full" + assert payload["requires_restart"] is True + + +def test_update_network_safety_settings_accepts_legacy_restricted_default_access( + tmp_path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config_path = tmp_path / "config.json" + save_config(Config(), config_path) + monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path) + monkeypatch.setattr("nanobot.webui.workspaces.get_webui_dir", lambda: tmp_path / "webui") + + payload = update_network_safety_settings({"webui_default_access_mode": ["restricted"]}) + + assert payload["advanced"]["webui_default_access_mode"] == "default" + + +def test_update_network_safety_settings_default_access_is_webui_only( + tmp_path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config_path = tmp_path / "config.json" + save_config(Config(), config_path) + before = config_path.read_text(encoding="utf-8") + monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path) + monkeypatch.setattr("nanobot.webui.workspaces.get_webui_dir", lambda: tmp_path / "webui") + + payload = update_network_safety_settings({"webui_default_access_mode": ["full"]}) + + saved = load_config(config_path) + assert config_path.read_text(encoding="utf-8") == before + assert saved.tools.restrict_to_workspace is False + assert payload["advanced"]["webui_default_access_mode"] == "full" + assert payload["requires_restart"] is False + + +def test_openai_codex_oauth_status_uses_available_token( + monkeypatch: pytest.MonkeyPatch, +) -> None: + def fake_get_token(): + return type( + "Token", + (), + { + "access": "access-token", + "refresh": "refresh-token", + "expires": 2_000_000_000_000, + "account_id": "acct-codex", + }, + )() + + monkeypatch.setattr("oauth_cli_kit.get_token", fake_get_token) + + status = _oauth_provider_status(find_by_name("openai_codex")) + + assert status["configured"] is True + assert status["account"] == "acct-codex" + + +def test_openai_codex_oauth_status_rejects_unavailable_token( + monkeypatch: pytest.MonkeyPatch, +) -> None: + def fake_get_token(): + raise RuntimeError("refresh failed") + + monkeypatch.setattr("oauth_cli_kit.get_token", fake_get_token) + + status = _oauth_provider_status(find_by_name("openai_codex")) + + assert status["configured"] is False + assert status["account"] is None + + +def test_create_model_configuration_accepts_configured_oauth_provider( + tmp_path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config_path = tmp_path / "config.json" + save_config(Config(), config_path) + monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path) + monkeypatch.setattr( + "nanobot.webui.settings_api._oauth_provider_status", + lambda spec: { + "configured": spec.name == "openai_codex", + "account": "acct-test", + "expires_at": 123, + "login_supported": True, + }, + ) + + payload = create_model_configuration( + { + "label": ["Codex"], + "provider": ["openai_codex"], + "model": ["openai-codex/gpt-5.1-codex"], + } + ) + + assert payload["agent"]["model_preset"] == "codex" + saved = load_config(config_path) + assert saved.model_presets["codex"].provider == "openai_codex" diff --git a/webui/src/App.tsx b/webui/src/App.tsx index 4ecd3aaa5..9ae092053 100644 --- a/webui/src/App.tsx +++ b/webui/src/App.tsx @@ -1,4 +1,5 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { Menu, Moon, Sun } from "lucide-react"; import { useTranslation } from "react-i18next"; import { DeleteConfirm } from "@/components/DeleteConfirm"; import { RenameChatDialog } from "@/components/RenameChatDialog"; @@ -23,9 +24,21 @@ import { import { deriveTitle } from "@/lib/format"; import { NanobotClient } from "@/lib/nanobot-client"; import { ClientProvider, useClient } from "@/providers/ClientProvider"; -import type { ChatSummary } from "@/lib/types"; +import type { + ChatSummary, + RuntimeSurface, + SettingsPayload, + WorkspaceScopePayload, + WorkspacesPayload, +} from "@/lib/types"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; +import { fetchSettings, fetchWorkspaces } from "@/lib/api"; +import { + createRuntimeHost, + toRuntimeSurface, +} from "@/lib/runtime"; +import { projectNameFromPath } from "@/lib/workspace"; type BootState = | { status: "loading" } @@ -37,6 +50,7 @@ type BootState = token: string; tokenExpiresAt: number; modelName: string | null; + runtimeSurface: RuntimeSurface; }; const SIDEBAR_STORAGE_KEY = "nanobot-webui.sidebar"; @@ -149,6 +163,67 @@ function writeCompletedRunChatIds(chatIds: Set): void { } } +function normalizeWorkspaceScope(scope: WorkspaceScopePayload): WorkspaceScopePayload { + const accessMode = scope.access_mode === "restricted" ? "restricted" : "full"; + return { + ...scope, + project_name: scope.project_name ?? projectNameFromPath(scope.project_path), + access_mode: accessMode, + restrict_to_workspace: accessMode === "restricted", + }; +} + +function HostChrome({ + onToggleSidebar, + theme, + onToggleTheme, + showThemeButton = true, +}: { + onToggleSidebar?: () => void; + theme: "light" | "dark"; + onToggleTheme: () => void; + showThemeButton?: boolean; +}) { + const { t } = useTranslation(); + + return ( +
+
+ {onToggleSidebar ? ( + + ) : null} +
+ {showThemeButton ? ( + + ) : ( +
+ )} +
+ ); +} + export default function App() { const { t } = useTranslation(); const [state, setState] = useState({ status: "loading" }); @@ -163,13 +238,20 @@ export default function App() { const boot = await fetchBootstrap("", secret); if (cancelled) return; if (secret) saveSecret(secret); - const url = deriveWsUrl(boot.ws_path, boot.token); + const url = deriveWsUrl(boot.ws_path, boot.token, boot.ws_url); + const runtimeSurface = toRuntimeSurface(boot.runtime_surface); + const runtimeHost = createRuntimeHost(runtimeSurface, boot.runtime_capabilities); const client = new NanobotClient({ url, + socketFactory: runtimeHost.socketFactory, onReauth: async () => { try { const refreshed = await fetchBootstrap("", bootstrapSecretRef.current); - const refreshedUrl = deriveWsUrl(refreshed.ws_path, refreshed.token); + const refreshedUrl = deriveWsUrl( + refreshed.ws_path, + refreshed.token, + refreshed.ws_url, + ); const tokenExpiresAt = bootstrapTokenExpiresAt(refreshed.expires_in); setState((current) => current.status === "ready" && current.client === client @@ -178,6 +260,10 @@ export default function App() { token: refreshed.token, tokenExpiresAt, modelName: refreshed.model_name ?? current.modelName, + runtimeSurface: + refreshed.runtime_surface + ? toRuntimeSurface(refreshed.runtime_surface) + : current.runtimeSurface, } : current, ); @@ -195,6 +281,7 @@ export default function App() { token: boot.token, tokenExpiresAt: bootstrapTokenExpiresAt(boot.expires_in), modelName: boot.model_name ?? null, + runtimeSurface, }); } catch (e) { if (cancelled) return; @@ -219,7 +306,7 @@ export default function App() { const timer = window.setTimeout(async () => { try { const boot = await fetchBootstrap("", bootstrapSecretRef.current); - const url = deriveWsUrl(boot.ws_path, boot.token); + const url = deriveWsUrl(boot.ws_path, boot.token, boot.ws_url); const tokenExpiresAt = bootstrapTokenExpiresAt(boot.expires_in); client.updateUrl(url); setState((current) => @@ -229,6 +316,9 @@ export default function App() { token: boot.token, tokenExpiresAt, modelName: boot.model_name ?? current.modelName, + runtimeSurface: boot.runtime_surface + ? toRuntimeSurface(boot.runtime_surface) + : current.runtimeSurface, } : current, ); @@ -304,20 +394,26 @@ export default function App() { token={state.token} modelName={state.modelName} > - + ); } function Shell({ + runtimeSurface, onModelNameChange, onLogout, }: { + runtimeSurface: RuntimeSurface; onModelNameChange: (modelName: string | null) => void; onLogout: () => void; }) { const { t, i18n } = useTranslation(); - const { client } = useClient(); + const { client, token } = useClient(); const { theme, toggle } = useTheme(); const { sessions, loading, refresh, createChat, deleteChat } = useSessions(); const { state: sidebarState, update: updateSidebarState } = @@ -325,7 +421,7 @@ function Shell({ const [activeKey, setActiveKey] = useState(null); const [view, setView] = useState("chat"); const [settingsInitialSection, setSettingsInitialSection] = useState("overview"); - const [desktopSidebarOpen, setDesktopSidebarOpen] = + const [hostSidebarOpen, setHostSidebarOpen] = useState(readSidebarOpen); const [mobileSidebarOpen, setMobileSidebarOpen] = useState(false); const [sessionSearchOpen, setSessionSearchOpen] = useState(false); @@ -337,23 +433,48 @@ function Shell({ key: string; label: string; } | null>(null); + const [pendingProjectRename, setPendingProjectRename] = useState<{ + key: string; + label: string; + } | null>(null); const restartSawDisconnectRef = useRef(false); const [restartToast, setRestartToast] = useState(null); const [isRestarting, setIsRestarting] = useState(false); const [runningChatIds, setRunningChatIds] = useState>(() => new Set()); const [completedChatIds, setCompletedChatIds] = useState>(readCompletedRunChatIds); + const [workspaces, setWorkspaces] = useState(null); + const [settingsSnapshot, setSettingsSnapshot] = useState(null); + const [workspaceError, setWorkspaceError] = useState(null); + const [draftWorkspaceScope, setDraftWorkspaceScope] = + useState(null); + const [workspaceOverrides, setWorkspaceOverrides] = + useState>({}); const runningChatIdsRef = useRef>(new Set()); + useEffect(() => { + let cancelled = false; + fetchSettings(token) + .then((payload) => { + if (!cancelled) setSettingsSnapshot(payload); + }) + .catch(() => { + if (!cancelled) setSettingsSnapshot(null); + }); + return () => { + cancelled = true; + }; + }, [token]); + useEffect(() => { try { window.localStorage.setItem( SIDEBAR_STORAGE_KEY, - desktopSidebarOpen ? "1" : "0", + hostSidebarOpen ? "1" : "0", ); } catch { // ignore storage errors (private mode, etc.) } - }, [desktopSidebarOpen]); + }, [hostSidebarOpen]); useEffect(() => { writeCompletedRunChatIds(completedChatIds); @@ -365,6 +486,36 @@ function Shell({ }, [sessions, activeKey]); const runningChatIdList = useMemo(() => Array.from(runningChatIds), [runningChatIds]); const completedChatIdList = useMemo(() => Array.from(completedChatIds), [completedChatIds]); + const activeChatId = activeSession?.chatId ?? null; + const activeWorkspaceScope = useMemo(() => { + if (activeChatId && workspaceOverrides[activeChatId]) { + return workspaceOverrides[activeChatId]; + } + if (activeSession?.workspaceScope) { + return activeSession.workspaceScope; + } + return draftWorkspaceScope ?? workspaces?.default_scope ?? null; + }, [ + activeChatId, + activeSession?.workspaceScope, + draftWorkspaceScope, + workspaceOverrides, + workspaces?.default_scope, + ]); + const activeChatRunning = activeChatId ? runningChatIds.has(activeChatId) : false; + + const refreshWorkspaces = useCallback(async () => { + try { + const payload = await fetchWorkspaces(token); + setWorkspaces(payload); + } catch { + setWorkspaces(null); + } + }, [token]); + + useEffect(() => { + void refreshWorkspaces(); + }, [refreshWorkspaces]); useEffect(() => { if (loading) return; @@ -375,8 +526,34 @@ function Shell({ ); return next.size === current.size ? current : next; }); + setWorkspaceOverrides((current) => { + const entries = Object.entries(current).filter(([chatId]) => knownChatIds.has(chatId)); + return entries.length === Object.keys(current).length ? current : Object.fromEntries(entries); + }); }, [loading, sessions]); + useEffect(() => { + return client.onSessionUpdate((_chatId, _scope, workspaceScope) => { + if (!workspaceScope) return; + const next = normalizeWorkspaceScope(workspaceScope); + setWorkspaceOverrides((current) => ({ + ...current, + [_chatId]: next, + })); + setDraftWorkspaceScope(next); + setWorkspaceError(null); + void refreshWorkspaces(); + }); + }, [client, refreshWorkspaces]); + + useEffect(() => { + return client.onError((error) => { + if (error.kind !== "workspace_scope_rejected") return; + setWorkspaceError(t("errors.workspaceScopeRejected.body")); + void refreshWorkspaces(); + }); + }, [client, refreshWorkspaces, t]); + useEffect(() => { if (loading) return; const activeRunIds = sessions @@ -408,12 +585,12 @@ function Shell({ }); }, [client, loading, sessions]); - const closeDesktopSidebar = useCallback(() => { - setDesktopSidebarOpen(false); + const closeHostSidebar = useCallback(() => { + setHostSidebarOpen(false); }, []); - const openDesktopSidebar = useCallback(() => { - setDesktopSidebarOpen(true); + const openHostSidebar = useCallback(() => { + setHostSidebarOpen(true); }, []); const closeMobileSidebar = useCallback(() => { @@ -421,38 +598,88 @@ function Shell({ }, []); const toggleSidebar = useCallback(() => { - const isDesktop = + const isNativeHost = typeof window !== "undefined" && window.matchMedia("(min-width: 1024px)").matches; - if (isDesktop) { - setDesktopSidebarOpen((v) => !v); + if (isNativeHost) { + setHostSidebarOpen((v) => !v); } else { setMobileSidebarOpen((v) => !v); } }, []); - const onCreateChat = useCallback(async () => { + const applyWorkspaceScope = useCallback( + (scope: WorkspaceScopePayload) => { + const next = normalizeWorkspaceScope(scope); + setWorkspaceError(null); + if (activeChatId) { + if (!activeChatRunning) { + client.setWorkspaceScope(activeChatId, next); + } + return; + } + setDraftWorkspaceScope(next); + }, + [activeChatId, activeChatRunning, client], + ); + + const onCreateChat = useCallback(async (workspaceScope?: WorkspaceScopePayload | null) => { try { - const chatId = await createChat(); + const scope = workspaceScope ?? activeWorkspaceScope; + const chatId = await createChat(scope); setActiveKey(`websocket:${chatId}`); setView("chat"); setMobileSidebarOpen(false); + if (scope) { + setWorkspaceOverrides((current) => ({ + ...current, + [chatId]: normalizeWorkspaceScope(scope), + })); + } return chatId; } catch (e) { console.error("Failed to create chat", e); + if (e instanceof Error && e.message.startsWith("workspace_scope_rejected:")) { + setWorkspaceError(t("errors.workspaceScopeRejected.body")); + } return null; } - }, [createChat]); + }, [activeWorkspaceScope, createChat, t]); const onNewChat = useCallback(() => { setActiveKey(null); + setDraftWorkspaceScope(null); + setWorkspaceError(null); setView("chat"); setMobileSidebarOpen(false); }, []); + const onNewChatInProject = useCallback( + (projectPath: string, projectName: string) => { + const base = workspaces?.default_scope ?? activeWorkspaceScope; + const trimmed = projectPath.trim(); + if (!base || !trimmed) { + onNewChat(); + return; + } + setActiveKey(null); + setDraftWorkspaceScope(normalizeWorkspaceScope({ + project_path: trimmed, + project_name: projectName || projectNameFromPath(trimmed), + access_mode: base.access_mode, + restrict_to_workspace: base.access_mode === "restricted", + })); + setWorkspaceError(null); + setView("chat"); + setMobileSidebarOpen(false); + }, + [activeWorkspaceScope, onNewChat, workspaces?.default_scope], + ); + const onSelectChat = useCallback( (key: string) => { - const selectedChatId = sessions.find((session) => session.key === key)?.chatId; + const selected = sessions.find((session) => session.key === key); + const selectedChatId = selected?.chatId; if (selectedChatId) { setCompletedChatIds((current) => { if (!current.has(selectedChatId)) return current; @@ -461,6 +688,12 @@ function Shell({ return next; }); } + if (selected?.workspaceScope) { + setDraftWorkspaceScope(normalizeWorkspaceScope(selected.workspaceScope)); + } else { + setDraftWorkspaceScope(null); + } + setWorkspaceError(null); setActiveKey(key); setView("chat"); setMobileSidebarOpen(false); @@ -512,6 +745,61 @@ function Shell({ [pendingRename, updateSidebarState], ); + const onToggleGroup = useCallback( + (groupId: string) => { + void updateSidebarState((current) => { + const collapsedGroups = { ...current.collapsed_groups }; + if (groupId === "workspace:chats" || groupId === "date:all") { + if (collapsedGroups[groupId] === false) { + delete collapsedGroups[groupId]; + } else { + collapsedGroups[groupId] = false; + } + return { + ...current, + collapsed_groups: collapsedGroups, + }; + } + if (collapsedGroups[groupId]) { + delete collapsedGroups[groupId]; + } else { + collapsedGroups[groupId] = true; + } + return { + ...current, + collapsed_groups: collapsedGroups, + }; + }); + }, + [updateSidebarState], + ); + + const onRequestRenameProject = useCallback((key: string, label: string) => { + setPendingProjectRename({ key, label }); + }, []); + + const onConfirmProjectRename = useCallback( + (title: string) => { + if (!pendingProjectRename) return; + const key = pendingProjectRename.key; + setPendingProjectRename(null); + void updateSidebarState((current) => { + const projectNameOverrides = { ...current.project_name_overrides }; + const cleaned = title.trim(); + if (cleaned) { + projectNameOverrides[key] = cleaned; + } else { + delete projectNameOverrides[key]; + } + return { + ...current, + project_name_overrides: projectNameOverrides, + }; + }); + }, + [pendingProjectRename, updateSidebarState], + ); + const onToggleArchive = useCallback( (key: string) => { void updateSidebarState((current) => { @@ -547,19 +835,6 @@ function Shell({ })); }, [updateSidebarState]); - const onUpdateSidebarView = useCallback( - (viewUpdate: Partial) => { - void updateSidebarState((current) => ({ - ...current, - view: { - ...current.view, - ...viewUpdate, - }, - })); - }, - [updateSidebarState], - ); - const onOpenSessionSearch = useCallback(() => { setMobileSidebarOpen(false); setSessionSearchOpen(true); @@ -742,117 +1017,165 @@ function Shell({ onTogglePin, onRequestRename, onToggleArchive, + onToggleGroup, + onRequestRenameProject, + onNewChatInProject, onOpenSettings, onOpenApps, onOpenSearch: onOpenSessionSearch, activeUtility: view === "apps" ? "apps" as const : null, onToggleArchived, - onUpdateView: onUpdateSidebarView, pinnedKeys: sidebarState.pinned_keys, archivedKeys: sidebarState.archived_keys, titleOverrides: sidebarState.title_overrides, + projectNameOverrides: sidebarState.project_name_overrides, + collapsedGroups: sidebarState.collapsed_groups, runningChatIds: runningChatIdList, completedChatIds: completedChatIdList, viewState: sidebarState.view, showArchived: sidebarState.view.show_archived, archivedCount: sidebarState.archived_keys.length, + defaultWorkspacePath: workspaces?.default_scope.project_path ?? null, }; + const effectiveRuntimeSurface = + settingsSnapshot?.surface ?? settingsSnapshot?.runtime_surface ?? runtimeSurface; + const isNativeHostSetupSurface = effectiveRuntimeSurface === "native"; + const showHostChrome = isNativeHostSetupSurface; const showMainSidebar = view !== "settings"; return ( -
- {/* Desktop sidebar: in normal flow, so the thread area width stays honest. */} - {showMainSidebar ? ( - - ) : null} - - {showMainSidebar ? ( - setMobileSidebarOpen(open)} - > - - {t("sidebar.navigation")} - - - - ) : null} - - - -
-
- -
- {view !== "chat" && ( -
-
- )} -
+ {view !== "chat" && ( +
+ +
+ )} + +
setPendingRename(null)} onConfirm={onConfirmRename} /> + setPendingProjectRename(null)} + onConfirm={onConfirmProjectRename} + /> {restartToast ? (
void; onRequestRename: (key: string, label: string) => void; onToggleArchive: (key: string) => void; + onToggleGroup?: (groupId: string) => void; + onRequestRenameProject?: (projectKey: string, label: string) => void; + onNewChatInProject?: (projectPath: string, projectName: string) => void; pinnedKeys?: string[]; archivedKeys?: string[]; titleOverrides?: Record; + projectNameOverrides?: Record; + collapsedGroups?: Record; runningChatIds?: string[]; completedChatIds?: string[]; density?: SidebarDensity; @@ -46,6 +64,7 @@ interface ChatListProps { showTimestamps?: boolean; sort?: SidebarSortMode; showArchived?: boolean; + defaultWorkspacePath?: string | null; actionMenuPortalContainer?: HTMLElement | null; loading?: boolean; emptyLabel?: string; @@ -59,9 +78,14 @@ export const ChatList = memo(function ChatList({ onTogglePin, onRequestRename, onToggleArchive, + onToggleGroup, + onRequestRenameProject, + onNewChatInProject, pinnedKeys = [], archivedKeys = [], titleOverrides = {}, + projectNameOverrides = {}, + collapsedGroups = {}, runningChatIds = [], completedChatIds = [], density = "comfortable", @@ -69,19 +93,21 @@ export const ChatList = memo(function ChatList({ showTimestamps = false, sort = "updated_desc", showArchived = false, + defaultWorkspacePath, actionMenuPortalContainer, loading, emptyLabel, }: ChatListProps) { const { t } = useTranslation(); const [visibleLimit, setVisibleLimit] = useState(INITIAL_VISIBLE_SESSIONS); - const labels = useMemo(() => ({ + const labels = useMemo(() => ({ pinned: t("chat.groups.pinned"), all: t("chat.groups.all"), today: t("chat.groups.today"), yesterday: t("chat.groups.yesterday"), earlier: t("chat.groups.earlier"), archived: t("chat.groups.archived"), + projects: t("chat.groups.projects"), fallbackTitle: t("chat.newChat"), }), [t]); const groups = useMemo( @@ -89,8 +115,10 @@ export const ChatList = memo(function ChatList({ pinnedKeys, archivedKeys, titleOverrides, + projectNameOverrides, showArchived, sort, + defaultWorkspacePath, }), [ archivedKeys, @@ -100,15 +128,21 @@ export const ChatList = memo(function ChatList({ showArchived, sort, titleOverrides, + projectNameOverrides, + defaultWorkspacePath, ], ); const limitedGroups = useMemo( - () => limitGroups(groups, visibleLimit, activeKey), - [activeKey, groups, visibleLimit], + () => limitGroups(groups, visibleLimit, activeKey, collapsedGroups), + [activeKey, collapsedGroups, groups, visibleLimit], ); const totalSessionCount = useMemo( - () => groups.reduce((total, group) => total + group.sessions.length, 0), - [groups], + () => groups.reduce( + (total, group) => + total + (isCollapsedProject(group, collapsedGroups) ? 0 : group.sessions.length), + 0, + ), + [collapsedGroups, groups], ); const visibleSessionCount = useMemo( () => limitedGroups.reduce((total, group) => total + group.sessions.length, 0), @@ -143,131 +177,194 @@ export const ChatList = memo(function ChatList({ const compact = density === "compact"; return ( -
+
- {limitedGroups.map((group) => ( -
-
- {group.label} -
-
    - {group.sessions.map((s) => { - const active = s.key === activeKey; - const fallbackTitle = t("chat.fallbackTitle", { - id: s.chatId.slice(0, 6), - }); - const generatedTitle = s.title?.trim() || ""; - const title = displayTitle(s, titleOverrides, t("chat.newChat")); - const tooltipTitle = - titleOverrides[s.key]?.trim() || - generatedTitle || - deriveTitle(s.preview, fallbackTitle); - const isPinned = pinned.has(s.key); - const isArchived = archived.has(s.key); - const preview = s.preview.trim(); - const showPreview = showPreviews && preview && preview !== title; - const timestamp = showTimestamps - ? relativeTime(s.updatedAt ?? s.createdAt) - : ""; - const activityState = running.has(s.chatId) - ? "running" - : completed.has(s.chatId) - ? "complete" - : null; - return ( -
  • -
    - - - - { + const foldableChatsGroup = isFoldableChatsGroup(group); + const foldedChatsGroup = isFoldedChatsGroup(group, collapsedGroups); + const visibleSessions = visibleSessionsForGroup( + group, + activeKey, + collapsedGroups, + ); + const hiddenInGroup = Math.max(0, group.sessions.length - visibleSessions.length); + const canToggleFold = group.sessions.length > COLLAPSED_CHATS_VISIBLE_COUNT; + + return ( +
    + {group.kind === "project" + && limitedGroups[index - 1]?.kind !== "project" ? ( +
    + {labels.projects} +
    + ) : null} + {group.kind === "project" ? ( + onToggleGroup?.(group.id)} + onRequestRename={ + group.projectKey && onRequestRenameProject + ? () => onRequestRenameProject(group.projectKey ?? "", group.label) + : undefined + } + onNewChat={ + group.projectPath && onNewChatInProject + ? () => onNewChatInProject(group.projectPath ?? "", group.label) + : undefined + } + actionMenuPortalContainer={actionMenuPortalContainer} + updatedAt={showTimestamps ? group.updatedAt : null} + /> + ) : ( + + )} + {group.kind === "project" && collapsedGroups[group.id] ? null : ( +
      + {visibleSessions.map((s) => { + const active = s.key === activeKey; + const fallbackTitle = t("chat.fallbackTitle", { + id: s.chatId.slice(0, 6), + }); + const generatedTitle = s.title?.trim() || ""; + const title = displayTitle(s, titleOverrides, t("chat.newChat")); + const tooltipTitle = + titleOverrides[s.key]?.trim() || + generatedTitle || + deriveTitle(s.preview, fallbackTitle); + const isPinned = pinned.has(s.key); + const isArchived = archived.has(s.key); + const preview = s.preview.trim(); + const showPreview = showPreviews && preview && preview !== title; + const timestamp = showTimestamps + ? relativeTime(s.updatedAt ?? s.createdAt) + : ""; + const projectMode = group.kind === "project"; + const activityState = running.has(s.chatId) + ? "running" + : completed.has(s.chatId) && !active + ? "complete" + : null; + return ( +
    • +
      - - - event.preventDefault()} - > - onTogglePin(s.key)} - > - {isPinned ? ( - - ) : ( - +
      -
    • - ); - })} -
    -
    - ))} + {showPreview ? ( + + {preview} + + ) : null} + {timestamp && !projectMode ? ( + + {timestamp} + + ) : null} + + + + + + + event.preventDefault()} + > + onTogglePin(s.key)} + > + {isPinned ? ( + + ) : ( + + )} + {isPinned ? t("chat.unpin") : t("chat.pin")} + + onRequestRename(s.key, title)} + > + + {t("chat.rename")} + + onToggleArchive(s.key)} + > + {isArchived ? ( + + ) : ( + + )} + {isArchived ? t("chat.unarchive") : t("chat.archive")} + + { + window.setTimeout(() => onRequestDelete(s.key, title), 0); + }} + className="text-destructive focus:text-destructive" + > + + {t("chat.delete")} + + + +
    +
  • + ); + })} +
+ )} + {foldableChatsGroup && canToggleFold ? ( + onToggleGroup?.(group.id)} + /> + ) : null} +
+ ); + })} {hiddenSessionCount > 0 ? (
@@ -288,6 +385,133 @@ export const ChatList = memo(function ChatList({ ); }); +function ProjectGroupHeader({ + label, + path, + collapsed, + onToggle, + onRequestRename, + onNewChat, + actionMenuPortalContainer, + updatedAt, +}: { + label: string; + path?: string; + collapsed: boolean; + onToggle: () => void; + onRequestRename?: () => void; + onNewChat?: () => void; + actionMenuPortalContainer?: HTMLElement | null; + updatedAt?: string | null; +}) { + const { t } = useTranslation(); + + return ( +
+ + {updatedAt ? ( + + {relativeTime(updatedAt)} + + ) : null} + {onRequestRename ? ( + + event.stopPropagation()} + > + + + event.preventDefault()} + > + + + {t("chat.rename")} + + + + ) : null} + {onNewChat ? ( + + ) : null} +
+ ); +} + +function ChatsGroupHeader({ label }: { label: string }) { + return ( +
+ {label} +
+ ); +} + +function ChatsFoldFooter({ + folded, + hiddenCount, + onToggle, +}: { + folded: boolean; + hiddenCount: number; + onToggle: () => void; +}) { + const { t, i18n } = useTranslation(); + const collapsedFallback = i18n.resolvedLanguage?.startsWith("zh") + ? `已折叠 ${hiddenCount} 个对话` + : `${hiddenCount} hidden chats`; + + return ( +
+ +
+ ); +} + function SessionActivityIndicator({ state, }: { @@ -316,202 +540,10 @@ function SessionActivityIndicator({ title={label} className="grid h-4 w-4 shrink-0 place-items-center" > - + ); } return