mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-30 06:45:55 +00:00
feat: add ConversationCallback for LiteLLM tracing
This commit is contained in:
parent
5fd66cae5c
commit
9f0ce3924d
@ -212,6 +212,7 @@ class AgentLoop:
|
|||||||
self,
|
self,
|
||||||
initial_messages: list[dict],
|
initial_messages: list[dict],
|
||||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
) -> tuple[str | None, list[str], list[dict]]:
|
) -> tuple[str | None, list[str], list[dict]]:
|
||||||
"""Run the agent iteration loop."""
|
"""Run the agent iteration loop."""
|
||||||
messages = initial_messages
|
messages = initial_messages
|
||||||
@ -228,6 +229,7 @@ class AgentLoop:
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tool_defs,
|
tools=tool_defs,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
usage = response.usage or {}
|
usage = response.usage or {}
|
||||||
self._last_usage = {
|
self._last_usage = {
|
||||||
@ -419,7 +421,9 @@ class AgentLoop:
|
|||||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||||
current_role=current_role,
|
current_role=current_role,
|
||||||
)
|
)
|
||||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
final_content, _, all_msgs = await self._run_agent_loop(
|
||||||
|
messages, metadata={"session_key": key},
|
||||||
|
)
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
||||||
@ -487,6 +491,7 @@ class AgentLoop:
|
|||||||
|
|
||||||
final_content, _, all_msgs = await self._run_agent_loop(
|
final_content, _, all_msgs = await self._run_agent_loop(
|
||||||
initial_messages, on_progress=on_progress or _bus_progress,
|
initial_messages, on_progress=on_progress or _bus_progress,
|
||||||
|
metadata={"session_key": key},
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_content is None:
|
if final_content is None:
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from nanobot.agent.tools.shell import ExecTool
|
|||||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.config.schema import ExecToolConfig
|
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
from nanobot.utils.helpers import build_assistant_message
|
from nanobot.utils.helpers import build_assistant_message
|
||||||
|
|
||||||
@ -29,13 +29,11 @@ class SubagentManager:
|
|||||||
workspace: Path,
|
workspace: Path,
|
||||||
bus: MessageBus,
|
bus: MessageBus,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
web_search_config: "WebSearchConfig | None" = None,
|
web_search_config: WebSearchConfig | None = None,
|
||||||
web_proxy: str | None = None,
|
web_proxy: str | None = None,
|
||||||
exec_config: "ExecToolConfig | None" = None,
|
exec_config: "ExecToolConfig | None" = None,
|
||||||
restrict_to_workspace: bool = False,
|
restrict_to_workspace: bool = False,
|
||||||
):
|
):
|
||||||
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
|
|
||||||
|
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
@ -113,6 +111,14 @@ class SubagentManager:
|
|||||||
{"role": "user", "content": task},
|
{"role": "user", "content": task},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Build metadata for tracing
|
||||||
|
session_key = f"{origin['channel']}:{origin['chat_id']}"
|
||||||
|
metadata = {
|
||||||
|
"session_key": session_key,
|
||||||
|
"agent_type": "subagent",
|
||||||
|
"task_id": task_id,
|
||||||
|
}
|
||||||
|
|
||||||
# Run agent loop (limited iterations)
|
# Run agent loop (limited iterations)
|
||||||
max_iterations = 15
|
max_iterations = 15
|
||||||
iteration = 0
|
iteration = 0
|
||||||
@ -125,6 +131,7 @@ class SubagentManager:
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools.get_definitions(),
|
tools=tools.get_definitions(),
|
||||||
model=self.model,
|
model=self.model,
|
||||||
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
|
|||||||
@ -442,6 +442,7 @@ def _make_provider(config: Config):
|
|||||||
default_model=model,
|
default_model=model,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
from nanobot.providers.callbacks import ConversationCallback
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||||
from nanobot.providers.registry import find_by_name
|
from nanobot.providers.registry import find_by_name
|
||||||
spec = find_by_name(provider_name)
|
spec = find_by_name(provider_name)
|
||||||
@ -449,12 +450,23 @@ def _make_provider(config: Config):
|
|||||||
console.print("[red]Error: No API key configured.[/red]")
|
console.print("[red]Error: No API key configured.[/red]")
|
||||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
# Create callback if tracing is enabled
|
||||||
|
callback = None
|
||||||
|
if config.tracing.enabled:
|
||||||
|
jsonl_path = config.tracing.jsonl_path
|
||||||
|
# Resolve relative paths relative to workspace
|
||||||
|
if jsonl_path:
|
||||||
|
jsonl_path = str(config.workspace_path / jsonl_path)
|
||||||
|
callback = ConversationCallback(jsonl_path=jsonl_path)
|
||||||
|
|
||||||
provider = LiteLLMProvider(
|
provider = LiteLLMProvider(
|
||||||
api_key=p.api_key if p else None,
|
api_key=p.api_key if p else None,
|
||||||
api_base=config.get_api_base(model),
|
api_base=config.get_api_base(model),
|
||||||
default_model=model,
|
default_model=model,
|
||||||
extra_headers=p.extra_headers if p else None,
|
extra_headers=p.extra_headers if p else None,
|
||||||
provider_name=provider_name,
|
provider_name=provider_name,
|
||||||
|
conversation_callback=callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
defaults = config.agents.defaults
|
defaults = config.agents.defaults
|
||||||
|
|||||||
@ -143,6 +143,17 @@ class ToolsConfig(Base):
|
|||||||
mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
|
mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class TracingConfig(Base):
    """Settings controlling how LLM calls are traced."""

    # Master switch: a conversation callback is only created when this is true.
    enabled: bool = True
    # Destination file for JSONL traces, e.g. "logs/traces.jsonl".
    # ``None`` means no trace file is written.
    jsonl_path: str | None = None

    # LangSmith integration (LiteLLM ships built-in support for it).
    # NOTE(review): these two fields are not read by any code visible in this
    # change; presumably the provider consumes them -- confirm.
    langsmith_enabled: bool = False  # Override env var detection
    langsmith_project: str = "nanobot"
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseSettings):
|
class Config(BaseSettings):
|
||||||
"""Root configuration for nanobot."""
|
"""Root configuration for nanobot."""
|
||||||
|
|
||||||
@ -151,6 +162,7 @@ class Config(BaseSettings):
|
|||||||
providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
|
providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
|
||||||
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
|
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
|
||||||
tools: ToolsConfig = Field(default_factory=ToolsConfig)
|
tools: ToolsConfig = Field(default_factory=ToolsConfig)
|
||||||
|
tracing: TracingConfig = Field(default_factory=TracingConfig)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def workspace_path(self) -> Path:
|
def workspace_path(self) -> Path:
|
||||||
|
|||||||
@ -232,6 +232,7 @@ class LLMProvider(ABC):
|
|||||||
temperature: object = _SENTINEL,
|
temperature: object = _SENTINEL,
|
||||||
reasoning_effort: object = _SENTINEL,
|
reasoning_effort: object = _SENTINEL,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""Call chat() with retry on transient provider failures.
|
"""Call chat() with retry on transient provider failures.
|
||||||
|
|
||||||
@ -250,6 +251,7 @@ class LLMProvider(ABC):
|
|||||||
messages=messages, tools=tools, model=model,
|
messages=messages, tools=tools, model=model,
|
||||||
max_tokens=max_tokens, temperature=temperature,
|
max_tokens=max_tokens, temperature=temperature,
|
||||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||||
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||||
|
|||||||
191
nanobot/providers/callbacks.py
Normal file
191
nanobot/providers/callbacks.py
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
"""LiteLLM callbacks for conversation tracing.
|
||||||
|
|
||||||
|
This module provides a non-invasive way to capture complete LLM conversation
|
||||||
|
traces, including agent and subagent trajectories, without modifying core
|
||||||
|
agent loop logic.
|
||||||
|
|
||||||
|
The callback receives kwargs["messages"] which contains the FULL conversation
|
||||||
|
history sent to the LLM - this is exactly what we need for trajectory persistence.
|
||||||
|
|
||||||
|
Reference: https://docs.litellm.ai/docs/observability/custom_callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_duration_ms(start_time: Any, end_time: Any) -> int:
    """Return the elapsed time between two instants in whole milliseconds.

    Both arguments may be numeric timestamps in seconds or ``datetime``
    objects, as long as they are of compatible types. The result is truncated
    toward zero.

    This is a best-effort helper for trace logging: any input that cannot be
    turned into a duration (``None``, mixed types, NaN/infinite floats)
    yields ``0`` instead of raising.
    """
    try:
        diff = end_time - start_time
        if isinstance(diff, timedelta):
            # Use exact integer arithmetic: going through total_seconds()
            # loses precision (int(1.001 * 1000) == 1000, not 1001).
            micros = diff // timedelta(microseconds=1)
            # Truncate toward zero, matching int() on the numeric path.
            return -(-micros // 1000) if micros < 0 else micros // 1000
        return int(diff * 1000)
    except (TypeError, AttributeError, ValueError, OverflowError):
        # ValueError / OverflowError come from int() on NaN / infinity.
        return 0
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationCallback(CustomLogger):
    """LiteLLM callback that records one trace entry per LLM call.

    For every successful completion the entry contains:
    - kwargs["messages"]: the full conversation history sent to the model
    - the first choice's content and finish reason, the model name and
      token usage
    - cost / cache information taken from kwargs
    - the caller-supplied metadata found in kwargs["litellm_params"]["metadata"]

    Failed completions are recorded as entries with ``event_type: "failure"``.

    Entries are appended as JSON lines to ``jsonl_path`` when one is
    configured; without a path the callback only emits log messages.

    Usage:
        callback = ConversationCallback(jsonl_path=Path("traces.jsonl"))

        # Register globally with LiteLLM:
        litellm.callbacks = [callback]

        # Or pass per call:
        kwargs.setdefault("callbacks", []).append(callback)

    Metadata is stored verbatim. Fields callers may supply include:
    - session_key: Session identifier (e.g., "cli:direct")
    - agent_type: "main" or "subagent"
    - parent_session: For subagents, the parent session key
    - task_id: For subagents, the task identifier
    - turn_count: Current turn number in the conversation
    """

    def __init__(self, jsonl_path: str | Path | None = None):
        """Create the callback.

        Args:
            jsonl_path: File to append trace entries to. ``None`` (or an
                empty string) disables file output.
        """
        super().__init__()
        self.jsonl_path = Path(jsonl_path) if jsonl_path else None
        # Serializes appends from concurrently running callbacks.
        self._write_lock = asyncio.Lock()

    @staticmethod
    def _extract_metadata(kwargs: dict[str, Any]) -> Any:
        """Return the caller metadata from LiteLLM kwargs, or ``{}`` if absent.

        Both ``litellm_params`` and its ``metadata`` entry may be present but
        ``None``; ``dict.get(key, {})`` would return that ``None`` and break
        the later ``metadata.get(...)`` lookups.
        """
        litellm_params = kwargs.get("litellm_params") or {}
        return litellm_params.get("metadata") or {}

    @staticmethod
    def _format_time(value: Any) -> Any:
        """Render ``datetime`` values as ISO strings; pass others through."""
        return value.isoformat() if isinstance(value, datetime) else value

    async def _append_entry(self, entry: dict[str, Any]) -> None:
        """Append ``entry`` as one JSON line to the trace file, if configured."""
        if not self.jsonl_path:
            return
        # Serialize before touching the file so a failure never leaves a
        # partial line behind. ``default=str`` is deliberate: a trace with a
        # stringified odd value is more useful than no trace at all.
        line = json.dumps(entry, ensure_ascii=False, default=str) + "\n"
        async with self._write_lock:
            self.jsonl_path.parent.mkdir(parents=True, exist_ok=True)
            with open(self.jsonl_path, "a", encoding="utf-8") as f:
                f.write(line)

    async def async_log_success_event(
        self,
        kwargs: dict[str, Any],
        response_obj: Any,
        start_time: datetime | float,
        end_time: datetime | float,
    ) -> None:
        """Called by LiteLLM after each successful LLM completion.

        Never raises: any error while building or writing the entry is
        logged as a warning so tracing cannot break the completion itself.

        Args:
            kwargs: Contains model, messages, litellm_params (with metadata), etc.
            response_obj: The completion response object
            start_time: Call start, as a datetime or numeric timestamp
            end_time: Call end, as a datetime or numeric timestamp
        """
        try:
            # Extract core data
            messages = kwargs.get("messages") or []
            model = kwargs.get("model", "unknown")

            # Extract metadata (session correlation)
            metadata = self._extract_metadata(kwargs)

            # Extract response content from the first choice, if any
            response_content = ""
            finish_reason = ""
            choices = getattr(response_obj, "choices", None)
            if choices:
                choice = choices[0]
                msg = getattr(choice, "message", None)
                response_content = getattr(msg, "content", "") or ""
                finish_reason = getattr(choice, "finish_reason", "") or ""

            # Extract usage stats
            usage = {}
            u = getattr(response_obj, "usage", None)
            if u:
                usage = {
                    "prompt_tokens": getattr(u, "prompt_tokens", 0),
                    "completion_tokens": getattr(u, "completion_tokens", 0),
                    "total_tokens": getattr(u, "total_tokens", 0),
                }

            # Extract cost and cache info (LiteLLM calculates these)
            cost = kwargs.get("response_cost", 0)
            cache_hit = kwargs.get("cache_hit", False)

            # Build trace entry
            entry = {
                "timestamp": datetime.now().isoformat(),
                "model": model,
                "messages": messages,  # FULL conversation history
                "response": response_content,
                "finish_reason": finish_reason,
                "usage": usage,
                "cost": cost,
                "cache_hit": cache_hit,
                "metadata": metadata,
                "start_time": self._format_time(start_time),
                "end_time": self._format_time(end_time),
                "duration_ms": _calc_duration_ms(start_time, end_time),
            }

            await self._append_entry(entry)

            logger.debug(
                "ConversationCallback: session={}, model={}, messages={}, cost={}",
                metadata.get("session_key", "unknown"),
                model,
                len(messages),
                cost,
            )

        except Exception as e:
            logger.warning("ConversationCallback error: {}", e)

    async def async_log_failure_event(
        self,
        kwargs: dict[str, Any],
        response_obj: Any,
        start_time: datetime | float,
        end_time: datetime | float,
    ) -> None:
        """Called by LiteLLM when a completion fails.

        Records the exception found in ``kwargs["exception"]`` together with
        the model, metadata and timing. Never raises; internal errors are
        logged.
        """
        try:
            model = kwargs.get("model", "unknown")
            metadata = self._extract_metadata(kwargs)
            exception = kwargs.get("exception", None)

            error_entry = {
                "timestamp": datetime.now().isoformat(),
                "event_type": "failure",
                "model": model,
                "metadata": metadata,
                "error": str(exception) if exception else "Unknown error",
                "error_type": type(exception).__name__ if exception else "Unknown",
                "start_time": self._format_time(start_time),
                "end_time": self._format_time(end_time),
                "duration_ms": _calc_duration_ms(start_time, end_time),
            }

            await self._append_entry(error_entry)

            logger.warning(
                "ConversationCallback failure: session={}, model={}, error={}",
                metadata.get("session_key", "unknown"),
                model,
                exception,
            )

        except Exception as e:
            logger.error("ConversationCallback failure logging error: {}", e)

    def __repr__(self) -> str:
        return f"ConversationCallback(jsonl_path={self.jsonl_path})"
|
||||||
@ -12,6 +12,7 @@ from litellm import acompletion
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
from nanobot.providers.callbacks import ConversationCallback
|
||||||
from nanobot.providers.registry import find_by_model, find_gateway
|
from nanobot.providers.registry import find_by_model, find_gateway
|
||||||
|
|
||||||
# Standard chat-completion message keys.
|
# Standard chat-completion message keys.
|
||||||
@ -40,10 +41,12 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
default_model: str = "anthropic/claude-opus-4-5",
|
default_model: str = "anthropic/claude-opus-4-5",
|
||||||
extra_headers: dict[str, str] | None = None,
|
extra_headers: dict[str, str] | None = None,
|
||||||
provider_name: str | None = None,
|
provider_name: str | None = None,
|
||||||
|
conversation_callback: ConversationCallback | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(api_key, api_base)
|
super().__init__(api_key, api_base)
|
||||||
self.default_model = default_model
|
self.default_model = default_model
|
||||||
self.extra_headers = extra_headers or {}
|
self.extra_headers = extra_headers or {}
|
||||||
|
self._conversation_callback = conversation_callback
|
||||||
|
|
||||||
# Detect gateway / local deployment.
|
# Detect gateway / local deployment.
|
||||||
# provider_name (from config key) is the primary signal;
|
# provider_name (from config key) is the primary signal;
|
||||||
@ -232,6 +235,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
reasoning_effort: str | None = None,
|
reasoning_effort: str | None = None,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""
|
"""
|
||||||
Send a chat completion request via LiteLLM.
|
Send a chat completion request via LiteLLM.
|
||||||
@ -270,8 +274,18 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
|
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
|
||||||
self._apply_model_overrides(model, kwargs)
|
self._apply_model_overrides(model, kwargs)
|
||||||
|
|
||||||
|
# Callbacks: custom conversation callback + LangSmith
|
||||||
|
callbacks_list = []
|
||||||
|
if self._conversation_callback:
|
||||||
|
callbacks_list.append(self._conversation_callback)
|
||||||
if self._langsmith_enabled:
|
if self._langsmith_enabled:
|
||||||
kwargs.setdefault("callbacks", []).append("langsmith")
|
callbacks_list.append("langsmith")
|
||||||
|
if callbacks_list:
|
||||||
|
kwargs["callbacks"] = callbacks_list
|
||||||
|
|
||||||
|
# Pass metadata for callback correlation (session_key, agent_type, etc.)
|
||||||
|
if metadata:
|
||||||
|
kwargs["metadata"] = metadata
|
||||||
|
|
||||||
# Pass api_key directly — more reliable than env vars alone
|
# Pass api_key directly — more reliable than env vars alone
|
||||||
if self.api_key:
|
if self.api_key:
|
||||||
|
|||||||
152
tests/providers/test_callbacks.py
Normal file
152
tests/providers/test_callbacks.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
# tests/providers/test_callbacks.py
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def test_callback_module_exists():
    """The callbacks module is importable and exposes ConversationCallback."""
    import nanobot.providers.callbacks as callbacks_module

    assert callbacks_module.ConversationCallback is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_callback_inherits_custom_logger():
    """The callback is a LiteLLM CustomLogger exposing both async hooks."""
    from litellm.integrations.custom_logger import CustomLogger
    from nanobot.providers.callbacks import ConversationCallback

    instance = ConversationCallback()

    assert isinstance(instance, CustomLogger)
    for hook_name in ("async_log_success_event", "async_log_failure_event"):
        assert hasattr(instance, hook_name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_async_log_success_event_extracts_full_messages(tmp_path):
    """Verify callback captures full message history from kwargs["messages"].

    The callback swallows its own exceptions, so "does not raise" proves
    nothing; the recorded JSONL entry is inspected instead.
    """
    import json

    from nanobot.providers.callbacks import ConversationCallback

    jsonl_path = tmp_path / "traces.jsonl"
    cb = ConversationCallback(jsonl_path=jsonl_path)

    # Simulate a subagent conversation with tool calls
    kwargs = {
        "model": "anthropic/claude-opus-4-5",
        "messages": [
            {"role": "system", "content": "You are a subagent..."},
            {"role": "user", "content": "Read test.txt and summarize it"},
            {"role": "assistant", "content": "I'll read the file.", "tool_calls": [
                {"id": "call_1", "type": "function", "function": {"name": "read_file", "arguments": '{"path": "test.txt"}'}}
            ]},
            {"role": "tool", "tool_call_id": "call_1", "name": "read_file", "content": "Hello World"},
            {"role": "assistant", "content": "The file contains: Hello World"},
        ],
        "litellm_params": {
            "metadata": {
                "session_key": "subagent:abc12345",
                "agent_type": "subagent",
                "parent_session": "cli:direct",
                "task_id": "abc12345",
            }
        },
        "response_cost": 0.0025,
        "cache_hit": False,
    }

    response_obj = MagicMock()
    response_obj.choices = [MagicMock()]
    response_obj.choices[0].message.content = "The file contains: Hello World"
    response_obj.choices[0].finish_reason = "stop"
    response_obj.usage = MagicMock()
    response_obj.usage.prompt_tokens = 150
    response_obj.usage.completion_tokens = 20
    response_obj.usage.total_tokens = 170

    await cb.async_log_success_event(kwargs, response_obj, 0.0, 0.5)

    lines = jsonl_path.read_text(encoding="utf-8").splitlines()
    assert len(lines) == 1
    entry = json.loads(lines[0])

    # The whole history, including tool call and tool result, is preserved.
    assert entry["messages"] == kwargs["messages"]
    assert entry["model"] == "anthropic/claude-opus-4-5"
    assert entry["response"] == "The file contains: Hello World"
    assert entry["finish_reason"] == "stop"
    assert entry["usage"] == {
        "prompt_tokens": 150,
        "completion_tokens": 20,
        "total_tokens": 170,
    }
    assert entry["cost"] == 0.0025
    assert entry["cache_hit"] is False
    assert entry["metadata"] == kwargs["litellm_params"]["metadata"]
    assert entry["duration_ms"] == 500
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_callback_writes_to_jsonl(tmp_path):
    """Verify callback writes exactly one well-formed trace line to JSONL."""
    import json

    from nanobot.providers.callbacks import ConversationCallback

    jsonl_path = tmp_path / "traces.jsonl"
    cb = ConversationCallback(jsonl_path=jsonl_path)

    kwargs = {
        "model": "test-model",
        "messages": [{"role": "user", "content": "hi"}],
        "litellm_params": {"metadata": {"session_key": "test:123"}},
        "response_cost": 0.001,
    }

    response_obj = MagicMock()
    response_obj.choices = [MagicMock()]
    response_obj.choices[0].message.content = "hello"
    response_obj.choices[0].finish_reason = "stop"
    response_obj.usage = MagicMock()
    response_obj.usage.prompt_tokens = 5
    response_obj.usage.completion_tokens = 5
    response_obj.usage.total_tokens = 10

    await cb.async_log_success_event(kwargs, response_obj, 0.0, 0.1)

    # Verify file was written; read it back as UTF-8, matching the writer.
    assert jsonl_path.exists()
    lines = jsonl_path.read_text(encoding="utf-8").splitlines()
    assert len(lines) == 1
    entry = json.loads(lines[0])

    assert entry["model"] == "test-model"
    assert len(entry["messages"]) == 1
    assert entry["metadata"]["session_key"] == "test:123"
    assert entry["response"] == "hello"
    assert entry["usage"]["total_tokens"] == 10
    # cache_hit was not supplied, so the callback's default applies.
    assert entry["cache_hit"] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_async_log_failure_event_extracts_fields(tmp_path):
    """Verify async_log_failure_event captures error details from kwargs.

    The callback swallows its own exceptions, so the recorded entry is
    checked rather than relying on "does not raise".
    """
    import json

    from nanobot.providers.callbacks import ConversationCallback

    jsonl_path = tmp_path / "traces.jsonl"
    cb = ConversationCallback(jsonl_path=jsonl_path)

    kwargs = {
        "model": "anthropic/claude-opus-4-5",
        "litellm_params": {
            "metadata": {
                "session_key": "cli:direct",
                "agent_type": "main",
            }
        },
        "exception": ValueError("Rate limit exceeded"),
    }

    await cb.async_log_failure_event(kwargs, None, 0.0, 1.0)

    lines = jsonl_path.read_text(encoding="utf-8").splitlines()
    assert len(lines) == 1
    entry = json.loads(lines[0])

    assert entry["event_type"] == "failure"
    assert entry["model"] == "anthropic/claude-opus-4-5"
    assert entry["metadata"] == {"session_key": "cli:direct", "agent_type": "main"}
    assert entry["error"] == "Rate limit exceeded"
    assert entry["error_type"] == "ValueError"
    assert entry["start_time"] == 0.0
    assert entry["end_time"] == 1.0
    assert entry["duration_ms"] == 1000
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_async_log_failure_event_writes_to_jsonl(tmp_path):
    """Verify failure event is written to JSONL with correct structure."""
    import json

    from nanobot.providers.callbacks import ConversationCallback

    jsonl_path = tmp_path / "traces.jsonl"
    cb = ConversationCallback(jsonl_path=jsonl_path)

    kwargs = {
        "model": "test-model",
        "litellm_params": {"metadata": {"session_key": "test:456"}},
        "exception": ValueError("API error"),
    }

    await cb.async_log_failure_event(kwargs, None, 0.0, 0.5)

    # Verify file was written; read it back as UTF-8, matching the writer.
    assert jsonl_path.exists()
    lines = jsonl_path.read_text(encoding="utf-8").splitlines()
    assert len(lines) == 1
    entry = json.loads(lines[0])

    assert entry["event_type"] == "failure"
    assert entry["model"] == "test-model"
    assert entry["metadata"]["session_key"] == "test:456"
    assert entry["error"] == "API error"
    assert entry["error_type"] == "ValueError"
    # 0.0 -> 0.5 seconds is exactly 500 ms; check the value, not just the key.
    assert entry["duration_ms"] == 500
|
||||||
Loading…
x
Reference in New Issue
Block a user