mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
feat(spawn): allow per-subagent sampling temperature (#3969)
This commit is contained in:
parent
ec99232208
commit
7a6cc657db
@ -140,6 +140,7 @@ class SubagentManager:
|
|||||||
origin_chat_id: str = "direct",
|
origin_chat_id: str = "direct",
|
||||||
session_key: str | None = None,
|
session_key: str | None = None,
|
||||||
origin_message_id: str | None = None,
|
origin_message_id: str | None = None,
|
||||||
|
temperature: float | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Spawn a subagent to execute a task in the background."""
|
"""Spawn a subagent to execute a task in the background."""
|
||||||
task_id = str(uuid.uuid4())[:8]
|
task_id = str(uuid.uuid4())[:8]
|
||||||
@ -155,7 +156,9 @@ class SubagentManager:
|
|||||||
self._task_statuses[task_id] = status
|
self._task_statuses[task_id] = status
|
||||||
|
|
||||||
bg_task = asyncio.create_task(
|
bg_task = asyncio.create_task(
|
||||||
self._run_subagent(task_id, task, display_label, origin, status, origin_message_id)
|
self._run_subagent(
|
||||||
|
task_id, task, display_label, origin, status, origin_message_id, temperature
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self._running_tasks[task_id] = bg_task
|
self._running_tasks[task_id] = bg_task
|
||||||
if session_key:
|
if session_key:
|
||||||
@ -182,6 +185,7 @@ class SubagentManager:
|
|||||||
origin: dict[str, str],
|
origin: dict[str, str],
|
||||||
status: SubagentStatus,
|
status: SubagentStatus,
|
||||||
origin_message_id: str | None = None,
|
origin_message_id: str | None = None,
|
||||||
|
temperature: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute the subagent task and announce the result."""
|
"""Execute the subagent task and announce the result."""
|
||||||
logger.info("Subagent [{}] starting task: {}", task_id, label)
|
logger.info("Subagent [{}] starting task: {}", task_id, label)
|
||||||
@ -208,6 +212,7 @@ class SubagentManager:
|
|||||||
initial_messages=messages,
|
initial_messages=messages,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
|
temperature=temperature,
|
||||||
max_iterations=self.max_iterations,
|
max_iterations=self.max_iterations,
|
||||||
max_tool_result_chars=self.max_tool_result_chars,
|
max_tool_result_chars=self.max_tool_result_chars,
|
||||||
hook=_SubagentHook(task_id, status),
|
hook=_SubagentHook(task_id, status),
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
|
|||||||
|
|
||||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||||
from nanobot.agent.tools.context import ContextAware, RequestContext
|
from nanobot.agent.tools.context import ContextAware, RequestContext
|
||||||
from nanobot.agent.tools.schema import StringSchema, tool_parameters_schema
|
from nanobot.agent.tools.schema import NumberSchema, StringSchema, tool_parameters_schema
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nanobot.agent.subagent import SubagentManager
|
from nanobot.agent.subagent import SubagentManager
|
||||||
@ -17,6 +17,15 @@ if TYPE_CHECKING:
|
|||||||
tool_parameters_schema(
|
tool_parameters_schema(
|
||||||
task=StringSchema("The task for the subagent to complete"),
|
task=StringSchema("The task for the subagent to complete"),
|
||||||
label=StringSchema("Optional short label for the task (for display)"),
|
label=StringSchema("Optional short label for the task (for display)"),
|
||||||
|
temperature=NumberSchema(
|
||||||
|
description=(
|
||||||
|
"Optional sampling temperature for the subagent "
|
||||||
|
"(0.0 = deterministic, higher = more creative). "
|
||||||
|
"Defaults to the provider's configured temperature."
|
||||||
|
),
|
||||||
|
minimum=0.0,
|
||||||
|
maximum=2.0,
|
||||||
|
),
|
||||||
required=["task"],
|
required=["task"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -58,7 +67,13 @@ class SpawnTool(Tool, ContextAware):
|
|||||||
"and use a dedicated subdirectory when helpful."
|
"and use a dedicated subdirectory when helpful."
|
||||||
)
|
)
|
||||||
|
|
||||||
async def execute(self, task: str, label: str | None = None, **kwargs: Any) -> str:
|
async def execute(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
label: str | None = None,
|
||||||
|
temperature: float | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
"""Spawn a subagent to execute the given task."""
|
"""Spawn a subagent to execute the given task."""
|
||||||
running = self._manager.get_running_count()
|
running = self._manager.get_running_count()
|
||||||
limit = self._manager.max_concurrent_subagents
|
limit = self._manager.max_concurrent_subagents
|
||||||
@ -75,4 +90,5 @@ class SpawnTool(Tool, ContextAware):
|
|||||||
origin_chat_id=self._origin_chat_id.get(),
|
origin_chat_id=self._origin_chat_id.get(),
|
||||||
session_key=self._session_key.get(),
|
session_key=self._session_key.get(),
|
||||||
origin_message_id=self._origin_message_id.get(),
|
origin_message_id=self._origin_message_id.get(),
|
||||||
|
temperature=temperature,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -94,6 +94,39 @@ async def test_subagent_uses_configured_max_iterations(tmp_path):
|
|||||||
mgr.runner.run.assert_awaited_once()
|
mgr.runner.run.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_spawn_forwards_temperature_to_run_spec(tmp_path):
|
||||||
|
"""A temperature passed to spawn() should reach the AgentRunSpec."""
|
||||||
|
from nanobot.agent.subagent import SubagentManager
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
mgr = SubagentManager(
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
bus=bus,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
)
|
||||||
|
mgr._announce_result = AsyncMock()
|
||||||
|
|
||||||
|
seen = {}
|
||||||
|
|
||||||
|
async def fake_run(spec):
|
||||||
|
seen["temperature"] = spec.temperature
|
||||||
|
return SimpleNamespace(
|
||||||
|
stop_reason="done", final_content="done", error=None, tool_events=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
mgr.runner.run = AsyncMock(side_effect=fake_run)
|
||||||
|
|
||||||
|
await mgr.spawn(task="do task", temperature=0.9)
|
||||||
|
await asyncio.gather(*mgr._running_tasks.values(), return_exceptions=True)
|
||||||
|
|
||||||
|
assert seen["temperature"] == 0.9
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_spawn_tool_rejects_when_at_concurrency_limit(tmp_path):
|
async def test_spawn_tool_rejects_when_at_concurrency_limit(tmp_path):
|
||||||
"""SpawnTool should return an error string when the concurrency limit is reached."""
|
"""SpawnTool should return an error string when the concurrency limit is reached."""
|
||||||
|
|||||||
@ -64,6 +64,7 @@ async def test_spawn_tool_keeps_task_local_context() -> None:
|
|||||||
origin_chat_id: str,
|
origin_chat_id: str,
|
||||||
session_key: str,
|
session_key: str,
|
||||||
origin_message_id: str | None = None,
|
origin_message_id: str | None = None,
|
||||||
|
temperature: float | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
seen.append((origin_channel, origin_chat_id, session_key))
|
seen.append((origin_channel, origin_chat_id, session_key))
|
||||||
return f"{origin_channel}:{origin_chat_id}:{task}"
|
return f"{origin_channel}:{origin_chat_id}:{task}"
|
||||||
@ -176,6 +177,7 @@ async def test_spawn_tool_basic_set_context_and_execute() -> None:
|
|||||||
origin_chat_id,
|
origin_chat_id,
|
||||||
session_key,
|
session_key,
|
||||||
origin_message_id=None,
|
origin_message_id=None,
|
||||||
|
temperature=None,
|
||||||
):
|
):
|
||||||
seen.append((origin_channel, origin_chat_id, session_key))
|
seen.append((origin_channel, origin_chat_id, session_key))
|
||||||
return f"ok: {task}"
|
return f"ok: {task}"
|
||||||
@ -208,6 +210,7 @@ async def test_spawn_tool_default_values_without_set_context() -> None:
|
|||||||
origin_chat_id,
|
origin_chat_id,
|
||||||
session_key,
|
session_key,
|
||||||
origin_message_id=None,
|
origin_message_id=None,
|
||||||
|
temperature=None,
|
||||||
):
|
):
|
||||||
seen.append((origin_channel, origin_chat_id, session_key))
|
seen.append((origin_channel, origin_chat_id, session_key))
|
||||||
return "ok"
|
return "ok"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user