# Source: mirror of https://github.com/HKUDS/nanobot.git (synced 2026-04-11)
"""LiteLLM callbacks for conversation tracing.
|
|
|
|
This module provides a non-invasive way to capture complete LLM conversation
|
|
traces, including agent and subagent trajectories, without modifying core
|
|
agent loop logic.
|
|
|
|
The callback receives kwargs["messages"] which contains the FULL conversation
|
|
history sent to the LLM - this is exactly what we need for trajectory persistence.
|
|
|
|
Reference: https://docs.litellm.ai/docs/observability/custom_callback
|
|
"""
import asyncio
import json
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any

from litellm.integrations.custom_logger import CustomLogger
from loguru import logger


def _calc_duration_ms(start_time: Any, end_time: Any) -> int:
|
|
"""Calculate duration in milliseconds, handling both float and datetime inputs."""
|
|
try:
|
|
diff = end_time - start_time
|
|
if isinstance(diff, timedelta):
|
|
return int(diff.total_seconds() * 1000)
|
|
return int(diff * 1000)
|
|
except (TypeError, AttributeError):
|
|
return 0
|
|
|
|
|
|
class ConversationCallback(CustomLogger):
    """LiteLLM callback for tracking full conversation traces.

    Captures complete LLM call context including:
    - kwargs["messages"]: Full conversation history (user + assistant + tool)
    - Response content, model, usage stats
    - Session metadata via kwargs["litellm_params"]["metadata"]

    This enables trajectory persistence without invasive modifications to
    agent loop or subagent code.

    Usage:
        callback = ConversationCallback(jsonl_path=Path("traces.jsonl"))

        # Register with LiteLLM (in LiteLLMProvider.__init__):
        litellm.callbacks = [callback]

        # Or pass via kwargs (current pattern in LiteLLMProvider.chat):
        kwargs.setdefault("callbacks", []).append(callback)

    Metadata Fields (passed via litellm_params.metadata):
    - session_key: Session identifier (e.g., "cli:direct", "subagent:abc123")
    - agent_type: "main" or "subagent"
    - parent_session: For subagents, the parent session key
    - task_id: For subagents, the task identifier
    - turn_count: Current turn number in the conversation
    """

    def __init__(self, jsonl_path: str | Path | None = None):
        super().__init__()
        # Destination for JSONL trace records; None disables persistence
        # (events are then only emitted via loguru).
        self.jsonl_path = Path(jsonl_path) if jsonl_path else None
        # Serializes file appends so concurrent agent/subagent completions
        # cannot interleave partial lines.
        self._write_lock = asyncio.Lock()

    @staticmethod
    def _serialize_time(value: Any) -> Any:
        """Return ISO-8601 text for datetime values; pass other values through.

        LiteLLM hands callbacks either Unix timestamps or datetime objects
        depending on version/code path, so both must be made JSON-safe.
        """
        return value.isoformat() if isinstance(value, datetime) else value

    def _write_line(self, line: str) -> None:
        """Synchronously append one pre-serialized JSONL line to the file."""
        with open(self.jsonl_path, "a", encoding="utf-8") as f:
            f.write(line)

    async def _append_jsonl(self, entry: dict[str, Any]) -> None:
        """Append one trace record to the JSONL file (no-op when disabled).

        The blocking file write is offloaded to a worker thread so the event
        loop is not stalled, and guarded by the lock for concurrent safety.
        """
        if not self.jsonl_path:
            return
        self.jsonl_path.parent.mkdir(parents=True, exist_ok=True)
        # default=str keeps best-effort tracing alive even when a message
        # payload contains non-JSON-native objects (it is stringified
        # instead of the whole record being dropped).
        line = json.dumps(entry, ensure_ascii=False, default=str) + "\n"
        async with self._write_lock:
            await asyncio.to_thread(self._write_line, line)

    async def async_log_success_event(
        self,
        kwargs: dict[str, Any],
        response_obj: Any,
        start_time: float,
        end_time: float,
    ) -> None:
        """Called by LiteLLM after each successful LLM completion.

        Args:
            kwargs: Contains model, messages, litellm_params (with metadata), etc.
            response_obj: The completion response object
            start_time: Unix timestamp (or datetime) of call start
            end_time: Unix timestamp (or datetime) of call end
        """
        try:
            # Core data: the FULL conversation history sent to the LLM.
            messages = kwargs.get("messages", [])
            model = kwargs.get("model", "unknown")

            # Session-correlation metadata (session_key, agent_type, ...).
            metadata = kwargs.get("litellm_params", {}).get("metadata", {})

            # Response content and finish reason (defensive: choices may be
            # absent or empty on some providers).
            response_content = ""
            finish_reason = ""
            if getattr(response_obj, "choices", None):
                choice = response_obj.choices[0]
                response_content = getattr(choice.message, "content", "") or ""
                finish_reason = getattr(choice, "finish_reason", "") or ""

            # Token usage stats, when the provider reports them.
            usage = {}
            if getattr(response_obj, "usage", None):
                u = response_obj.usage
                usage = {
                    "prompt_tokens": getattr(u, "prompt_tokens", 0),
                    "completion_tokens": getattr(u, "completion_tokens", 0),
                    "total_tokens": getattr(u, "total_tokens", 0),
                }

            entry = {
                "timestamp": datetime.now().isoformat(),
                "model": model,
                "messages": messages,  # FULL conversation history
                "response": response_content,
                "finish_reason": finish_reason,
                "usage": usage,
                # Cost and cache info are computed by LiteLLM itself.
                "cost": kwargs.get("response_cost", 0),
                "cache_hit": kwargs.get("cache_hit", False),
                "metadata": metadata,
                "start_time": self._serialize_time(start_time),
                "end_time": self._serialize_time(end_time),
                "duration_ms": _calc_duration_ms(start_time, end_time),
            }

            await self._append_jsonl(entry)

            logger.debug(
                "ConversationCallback: session={}, model={}, messages={}, cost={}",
                metadata.get("session_key", "unknown"),
                model,
                len(messages),
                entry["cost"],
            )

        except Exception as e:
            # Tracing is best-effort: never let observability break the call path.
            logger.warning("ConversationCallback error: {}", e)

    async def async_log_failure_event(
        self,
        kwargs: dict[str, Any],
        response_obj: Any,
        start_time: float,
        end_time: float,
    ) -> None:
        """Called by LiteLLM when a completion fails.

        Writes a failure record (event_type="failure") carrying the error
        text/type alongside the usual session metadata and timing fields.
        """
        try:
            model = kwargs.get("model", "unknown")
            metadata = kwargs.get("litellm_params", {}).get("metadata", {})
            exception = kwargs.get("exception", None)

            error_entry = {
                "timestamp": datetime.now().isoformat(),
                "event_type": "failure",
                "model": model,
                "metadata": metadata,
                "error": str(exception) if exception else "Unknown error",
                "error_type": type(exception).__name__ if exception else "Unknown",
                "start_time": self._serialize_time(start_time),
                "end_time": self._serialize_time(end_time),
                "duration_ms": _calc_duration_ms(start_time, end_time),
            }

            await self._append_jsonl(error_entry)

            logger.warning(
                "ConversationCallback failure: session={}, model={}, error={}",
                metadata.get("session_key", "unknown"),
                model,
                exception,
            )

        except Exception as e:
            logger.error("ConversationCallback failure logging error: {}", e)

    def __repr__(self) -> str:
        return f"ConversationCallback(jsonl_path={self.jsonl_path})"