feat(spawn): allow per-subagent sampling temperature (#3969)

This commit is contained in:
04cb 2026-05-23 22:51:56 +08:00 committed by Xubin Ren
parent ec99232208
commit 7a6cc657db
4 changed files with 60 additions and 3 deletions

View File

@ -140,6 +140,7 @@ class SubagentManager:
origin_chat_id: str = "direct", origin_chat_id: str = "direct",
session_key: str | None = None, session_key: str | None = None,
origin_message_id: str | None = None, origin_message_id: str | None = None,
temperature: float | None = None,
) -> str: ) -> str:
"""Spawn a subagent to execute a task in the background.""" """Spawn a subagent to execute a task in the background."""
task_id = str(uuid.uuid4())[:8] task_id = str(uuid.uuid4())[:8]
@ -155,7 +156,9 @@ class SubagentManager:
self._task_statuses[task_id] = status self._task_statuses[task_id] = status
bg_task = asyncio.create_task( bg_task = asyncio.create_task(
self._run_subagent(task_id, task, display_label, origin, status, origin_message_id) self._run_subagent(
task_id, task, display_label, origin, status, origin_message_id, temperature
)
) )
self._running_tasks[task_id] = bg_task self._running_tasks[task_id] = bg_task
if session_key: if session_key:
@ -182,6 +185,7 @@ class SubagentManager:
origin: dict[str, str], origin: dict[str, str],
status: SubagentStatus, status: SubagentStatus,
origin_message_id: str | None = None, origin_message_id: str | None = None,
temperature: float | None = None,
) -> None: ) -> None:
"""Execute the subagent task and announce the result.""" """Execute the subagent task and announce the result."""
logger.info("Subagent [{}] starting task: {}", task_id, label) logger.info("Subagent [{}] starting task: {}", task_id, label)
@ -208,6 +212,7 @@ class SubagentManager:
initial_messages=messages, initial_messages=messages,
tools=tools, tools=tools,
model=self.model, model=self.model,
temperature=temperature,
max_iterations=self.max_iterations, max_iterations=self.max_iterations,
max_tool_result_chars=self.max_tool_result_chars, max_tool_result_chars=self.max_tool_result_chars,
hook=_SubagentHook(task_id, status), hook=_SubagentHook(task_id, status),

View File

@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
from nanobot.agent.tools.base import Tool, tool_parameters from nanobot.agent.tools.base import Tool, tool_parameters
from nanobot.agent.tools.context import ContextAware, RequestContext from nanobot.agent.tools.context import ContextAware, RequestContext
from nanobot.agent.tools.schema import StringSchema, tool_parameters_schema from nanobot.agent.tools.schema import NumberSchema, StringSchema, tool_parameters_schema
if TYPE_CHECKING: if TYPE_CHECKING:
from nanobot.agent.subagent import SubagentManager from nanobot.agent.subagent import SubagentManager
@ -17,6 +17,15 @@ if TYPE_CHECKING:
tool_parameters_schema( tool_parameters_schema(
task=StringSchema("The task for the subagent to complete"), task=StringSchema("The task for the subagent to complete"),
label=StringSchema("Optional short label for the task (for display)"), label=StringSchema("Optional short label for the task (for display)"),
temperature=NumberSchema(
description=(
"Optional sampling temperature for the subagent "
"(0.0 = deterministic, higher = more creative). "
"Defaults to the provider's configured temperature."
),
minimum=0.0,
maximum=2.0,
),
required=["task"], required=["task"],
) )
) )
@ -58,7 +67,13 @@ class SpawnTool(Tool, ContextAware):
"and use a dedicated subdirectory when helpful." "and use a dedicated subdirectory when helpful."
) )
async def execute(self, task: str, label: str | None = None, **kwargs: Any) -> str: async def execute(
self,
task: str,
label: str | None = None,
temperature: float | None = None,
**kwargs: Any,
) -> str:
"""Spawn a subagent to execute the given task.""" """Spawn a subagent to execute the given task."""
running = self._manager.get_running_count() running = self._manager.get_running_count()
limit = self._manager.max_concurrent_subagents limit = self._manager.max_concurrent_subagents
@ -75,4 +90,5 @@ class SpawnTool(Tool, ContextAware):
origin_chat_id=self._origin_chat_id.get(), origin_chat_id=self._origin_chat_id.get(),
session_key=self._session_key.get(), session_key=self._session_key.get(),
origin_message_id=self._origin_message_id.get(), origin_message_id=self._origin_message_id.get(),
temperature=temperature,
) )

View File

@ -94,6 +94,39 @@ async def test_subagent_uses_configured_max_iterations(tmp_path):
mgr.runner.run.assert_awaited_once() mgr.runner.run.assert_awaited_once()
@pytest.mark.asyncio
async def test_spawn_forwards_temperature_to_run_spec(tmp_path):
"""A temperature passed to spawn() should reach the AgentRunSpec."""
from nanobot.agent.subagent import SubagentManager
from nanobot.bus.queue import MessageBus
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
mgr._announce_result = AsyncMock()
seen = {}
async def fake_run(spec):
seen["temperature"] = spec.temperature
return SimpleNamespace(
stop_reason="done", final_content="done", error=None, tool_events=[],
)
mgr.runner.run = AsyncMock(side_effect=fake_run)
await mgr.spawn(task="do task", temperature=0.9)
await asyncio.gather(*mgr._running_tasks.values(), return_exceptions=True)
assert seen["temperature"] == 0.9
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_spawn_tool_rejects_when_at_concurrency_limit(tmp_path): async def test_spawn_tool_rejects_when_at_concurrency_limit(tmp_path):
"""SpawnTool should return an error string when the concurrency limit is reached.""" """SpawnTool should return an error string when the concurrency limit is reached."""

View File

@ -64,6 +64,7 @@ async def test_spawn_tool_keeps_task_local_context() -> None:
origin_chat_id: str, origin_chat_id: str,
session_key: str, session_key: str,
origin_message_id: str | None = None, origin_message_id: str | None = None,
temperature: float | None = None,
) -> str: ) -> str:
seen.append((origin_channel, origin_chat_id, session_key)) seen.append((origin_channel, origin_chat_id, session_key))
return f"{origin_channel}:{origin_chat_id}:{task}" return f"{origin_channel}:{origin_chat_id}:{task}"
@ -176,6 +177,7 @@ async def test_spawn_tool_basic_set_context_and_execute() -> None:
origin_chat_id, origin_chat_id,
session_key, session_key,
origin_message_id=None, origin_message_id=None,
temperature=None,
): ):
seen.append((origin_channel, origin_chat_id, session_key)) seen.append((origin_channel, origin_chat_id, session_key))
return f"ok: {task}" return f"ok: {task}"
@ -208,6 +210,7 @@ async def test_spawn_tool_default_values_without_set_context() -> None:
origin_chat_id, origin_chat_id,
session_key, session_key,
origin_message_id=None, origin_message_id=None,
temperature=None,
): ):
seen.append((origin_channel, origin_chat_id, session_key)) seen.append((origin_channel, origin_chat_id, session_key))
return "ok" return "ok"