mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
feat(spawn): allow per-subagent sampling temperature (#3969)
This commit is contained in:
parent
ec99232208
commit
7a6cc657db
@ -140,6 +140,7 @@ class SubagentManager:
|
||||
origin_chat_id: str = "direct",
|
||||
session_key: str | None = None,
|
||||
origin_message_id: str | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> str:
|
||||
"""Spawn a subagent to execute a task in the background."""
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
@ -155,7 +156,9 @@ class SubagentManager:
|
||||
self._task_statuses[task_id] = status
|
||||
|
||||
bg_task = asyncio.create_task(
|
||||
self._run_subagent(task_id, task, display_label, origin, status, origin_message_id)
|
||||
self._run_subagent(
|
||||
task_id, task, display_label, origin, status, origin_message_id, temperature
|
||||
)
|
||||
)
|
||||
self._running_tasks[task_id] = bg_task
|
||||
if session_key:
|
||||
@ -182,6 +185,7 @@ class SubagentManager:
|
||||
origin: dict[str, str],
|
||||
status: SubagentStatus,
|
||||
origin_message_id: str | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> None:
|
||||
"""Execute the subagent task and announce the result."""
|
||||
logger.info("Subagent [{}] starting task: {}", task_id, label)
|
||||
@ -208,6 +212,7 @@ class SubagentManager:
|
||||
initial_messages=messages,
|
||||
tools=tools,
|
||||
model=self.model,
|
||||
temperature=temperature,
|
||||
max_iterations=self.max_iterations,
|
||||
max_tool_result_chars=self.max_tool_result_chars,
|
||||
hook=_SubagentHook(task_id, status),
|
||||
|
||||
@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.context import ContextAware, RequestContext
|
||||
from nanobot.agent.tools.schema import StringSchema, tool_parameters_schema
|
||||
from nanobot.agent.tools.schema import NumberSchema, StringSchema, tool_parameters_schema
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
@ -17,6 +17,15 @@ if TYPE_CHECKING:
|
||||
tool_parameters_schema(
|
||||
task=StringSchema("The task for the subagent to complete"),
|
||||
label=StringSchema("Optional short label for the task (for display)"),
|
||||
temperature=NumberSchema(
|
||||
description=(
|
||||
"Optional sampling temperature for the subagent "
|
||||
"(0.0 = deterministic, higher = more creative). "
|
||||
"Defaults to the provider's configured temperature."
|
||||
),
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
),
|
||||
required=["task"],
|
||||
)
|
||||
)
|
||||
@ -58,7 +67,13 @@ class SpawnTool(Tool, ContextAware):
|
||||
"and use a dedicated subdirectory when helpful."
|
||||
)
|
||||
|
||||
async def execute(self, task: str, label: str | None = None, **kwargs: Any) -> str:
|
||||
async def execute(
|
||||
self,
|
||||
task: str,
|
||||
label: str | None = None,
|
||||
temperature: float | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Spawn a subagent to execute the given task."""
|
||||
running = self._manager.get_running_count()
|
||||
limit = self._manager.max_concurrent_subagents
|
||||
@ -75,4 +90,5 @@ class SpawnTool(Tool, ContextAware):
|
||||
origin_chat_id=self._origin_chat_id.get(),
|
||||
session_key=self._session_key.get(),
|
||||
origin_message_id=self._origin_message_id.get(),
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
@ -94,6 +94,39 @@ async def test_subagent_uses_configured_max_iterations(tmp_path):
|
||||
mgr.runner.run.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spawn_forwards_temperature_to_run_spec(tmp_path):
|
||||
"""A temperature passed to spawn() should reach the AgentRunSpec."""
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
bus=bus,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
)
|
||||
mgr._announce_result = AsyncMock()
|
||||
|
||||
seen = {}
|
||||
|
||||
async def fake_run(spec):
|
||||
seen["temperature"] = spec.temperature
|
||||
return SimpleNamespace(
|
||||
stop_reason="done", final_content="done", error=None, tool_events=[],
|
||||
)
|
||||
|
||||
mgr.runner.run = AsyncMock(side_effect=fake_run)
|
||||
|
||||
await mgr.spawn(task="do task", temperature=0.9)
|
||||
await asyncio.gather(*mgr._running_tasks.values(), return_exceptions=True)
|
||||
|
||||
assert seen["temperature"] == 0.9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spawn_tool_rejects_when_at_concurrency_limit(tmp_path):
|
||||
"""SpawnTool should return an error string when the concurrency limit is reached."""
|
||||
|
||||
@ -64,6 +64,7 @@ async def test_spawn_tool_keeps_task_local_context() -> None:
|
||||
origin_chat_id: str,
|
||||
session_key: str,
|
||||
origin_message_id: str | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> str:
|
||||
seen.append((origin_channel, origin_chat_id, session_key))
|
||||
return f"{origin_channel}:{origin_chat_id}:{task}"
|
||||
@ -176,6 +177,7 @@ async def test_spawn_tool_basic_set_context_and_execute() -> None:
|
||||
origin_chat_id,
|
||||
session_key,
|
||||
origin_message_id=None,
|
||||
temperature=None,
|
||||
):
|
||||
seen.append((origin_channel, origin_chat_id, session_key))
|
||||
return f"ok: {task}"
|
||||
@ -208,6 +210,7 @@ async def test_spawn_tool_default_values_without_set_context() -> None:
|
||||
origin_chat_id,
|
||||
session_key,
|
||||
origin_message_id=None,
|
||||
temperature=None,
|
||||
):
|
||||
seen.append((origin_channel, origin_chat_id, session_key))
|
||||
return "ok"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user