mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-11 13:43:37 +00:00
fix(provider): accept plain text OpenAI-compatible responses
Handle string and dict-shaped responses from OpenAI-compatible backends so non-standard providers no longer crash on missing choices fields. Add regression tests to keep SDK, dict, and plain-text parsing paths aligned.
This commit is contained in:
parent
321214e2e0
commit
263069583d
@ -193,7 +193,126 @@ class OpenAICompatProvider(LLMProvider):
|
||||
# Response parsing
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _maybe_mapping(value: Any) -> dict[str, Any] | None:
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
model_dump = getattr(value, "model_dump", None)
|
||||
if callable(model_dump):
|
||||
dumped = model_dump()
|
||||
if isinstance(dumped, dict):
|
||||
return dumped
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_text_content(cls, value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, list):
|
||||
parts: list[str] = []
|
||||
for item in value:
|
||||
item_map = cls._maybe_mapping(item)
|
||||
if item_map:
|
||||
text = item_map.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
continue
|
||||
text = getattr(item, "text", None)
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
continue
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
return "".join(parts) or None
|
||||
return str(value)
|
||||
|
||||
@classmethod
|
||||
def _extract_usage(cls, response: Any) -> dict[str, int]:
|
||||
usage_obj = None
|
||||
response_map = cls._maybe_mapping(response)
|
||||
if response_map is not None:
|
||||
usage_obj = response_map.get("usage")
|
||||
elif hasattr(response, "usage") and response.usage:
|
||||
usage_obj = response.usage
|
||||
|
||||
usage_map = cls._maybe_mapping(usage_obj)
|
||||
if usage_map is not None:
|
||||
return {
|
||||
"prompt_tokens": int(usage_map.get("prompt_tokens") or 0),
|
||||
"completion_tokens": int(usage_map.get("completion_tokens") or 0),
|
||||
"total_tokens": int(usage_map.get("total_tokens") or 0),
|
||||
}
|
||||
|
||||
if usage_obj:
|
||||
return {
|
||||
"prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0,
|
||||
"completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0,
|
||||
"total_tokens": getattr(usage_obj, "total_tokens", 0) or 0,
|
||||
}
|
||||
return {}
|
||||
|
||||
def _parse(self, response: Any) -> LLMResponse:
|
||||
if isinstance(response, str):
|
||||
return LLMResponse(content=response, finish_reason="stop")
|
||||
|
||||
response_map = self._maybe_mapping(response)
|
||||
if response_map is not None:
|
||||
choices = response_map.get("choices") or []
|
||||
if not choices:
|
||||
content = self._extract_text_content(
|
||||
response_map.get("content") or response_map.get("output_text")
|
||||
)
|
||||
if content is not None:
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
finish_reason=str(response_map.get("finish_reason") or "stop"),
|
||||
usage=self._extract_usage(response_map),
|
||||
)
|
||||
return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")
|
||||
|
||||
choice0 = self._maybe_mapping(choices[0]) or {}
|
||||
msg0 = self._maybe_mapping(choice0.get("message")) or {}
|
||||
content = self._extract_text_content(msg0.get("content"))
|
||||
finish_reason = str(choice0.get("finish_reason") or "stop")
|
||||
|
||||
raw_tool_calls: list[Any] = []
|
||||
reasoning_content = msg0.get("reasoning_content")
|
||||
for ch in choices:
|
||||
ch_map = self._maybe_mapping(ch) or {}
|
||||
m = self._maybe_mapping(ch_map.get("message")) or {}
|
||||
tool_calls = m.get("tool_calls")
|
||||
if isinstance(tool_calls, list) and tool_calls:
|
||||
raw_tool_calls.extend(tool_calls)
|
||||
if ch_map.get("finish_reason") in ("tool_calls", "stop"):
|
||||
finish_reason = str(ch_map["finish_reason"])
|
||||
if not content:
|
||||
content = self._extract_text_content(m.get("content"))
|
||||
if not reasoning_content:
|
||||
reasoning_content = m.get("reasoning_content")
|
||||
|
||||
parsed_tool_calls = []
|
||||
for tc in raw_tool_calls:
|
||||
tc_map = self._maybe_mapping(tc) or {}
|
||||
fn = self._maybe_mapping(tc_map.get("function")) or {}
|
||||
args = fn.get("arguments", {})
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
parsed_tool_calls.append(ToolCallRequest(
|
||||
id=_short_tool_id(),
|
||||
name=str(fn.get("name") or ""),
|
||||
arguments=args if isinstance(args, dict) else {},
|
||||
))
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=parsed_tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
usage=self._extract_usage(response_map),
|
||||
reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
|
||||
)
|
||||
|
||||
if not response.choices:
|
||||
return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")
|
||||
|
||||
@ -223,39 +342,60 @@ class OpenAICompatProvider(LLMProvider):
|
||||
arguments=args,
|
||||
))
|
||||
|
||||
usage: dict[str, int] = {}
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
u = response.usage
|
||||
usage = {
|
||||
"prompt_tokens": u.prompt_tokens or 0,
|
||||
"completion_tokens": u.completion_tokens or 0,
|
||||
"total_tokens": u.total_tokens or 0,
|
||||
}
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason or "stop",
|
||||
usage=usage,
|
||||
usage=self._extract_usage(response),
|
||||
reasoning_content=getattr(msg, "reasoning_content", None) or None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_chunks(chunks: list[Any]) -> LLMResponse:
|
||||
@classmethod
|
||||
def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse:
|
||||
content_parts: list[str] = []
|
||||
tc_bufs: dict[int, dict[str, str]] = {}
|
||||
finish_reason = "stop"
|
||||
usage: dict[str, int] = {}
|
||||
|
||||
for chunk in chunks:
|
||||
if isinstance(chunk, str):
|
||||
content_parts.append(chunk)
|
||||
continue
|
||||
|
||||
chunk_map = cls._maybe_mapping(chunk)
|
||||
if chunk_map is not None:
|
||||
choices = chunk_map.get("choices") or []
|
||||
if not choices:
|
||||
usage = cls._extract_usage(chunk_map) or usage
|
||||
text = cls._extract_text_content(
|
||||
chunk_map.get("content") or chunk_map.get("output_text")
|
||||
)
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
continue
|
||||
choice = cls._maybe_mapping(choices[0]) or {}
|
||||
if choice.get("finish_reason"):
|
||||
finish_reason = str(choice["finish_reason"])
|
||||
delta = cls._maybe_mapping(choice.get("delta")) or {}
|
||||
text = cls._extract_text_content(delta.get("content"))
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
for idx, tc in enumerate(delta.get("tool_calls") or []):
|
||||
tc_map = cls._maybe_mapping(tc) or {}
|
||||
tc_index = tc_map.get("index", idx)
|
||||
buf = tc_bufs.setdefault(tc_index, {"id": "", "name": "", "arguments": ""})
|
||||
if tc_map.get("id"):
|
||||
buf["id"] = str(tc_map["id"])
|
||||
fn = cls._maybe_mapping(tc_map.get("function")) or {}
|
||||
if fn.get("name"):
|
||||
buf["name"] = str(fn["name"])
|
||||
if fn.get("arguments"):
|
||||
buf["arguments"] += str(fn["arguments"])
|
||||
usage = cls._extract_usage(chunk_map) or usage
|
||||
continue
|
||||
|
||||
if not chunk.choices:
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
u = chunk.usage
|
||||
usage = {
|
||||
"prompt_tokens": u.prompt_tokens or 0,
|
||||
"completion_tokens": u.completion_tokens or 0,
|
||||
"total_tokens": u.total_tokens or 0,
|
||||
}
|
||||
usage = cls._extract_usage(chunk) or usage
|
||||
continue
|
||||
choice = chunk.choices[0]
|
||||
if choice.finish_reason:
|
||||
|
||||
@ -15,3 +15,41 @@ def test_custom_provider_parse_handles_empty_choices() -> None:
|
||||
|
||||
assert result.finish_reason == "error"
|
||||
assert "empty choices" in result.content
|
||||
|
||||
|
||||
def test_custom_provider_parse_accepts_plain_string_response() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
result = provider._parse("hello from backend")
|
||||
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.content == "hello from backend"
|
||||
|
||||
|
||||
def test_custom_provider_parse_accepts_dict_response() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
result = provider._parse({
|
||||
"choices": [{
|
||||
"message": {"content": "hello from dict"},
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 2,
|
||||
"total_tokens": 3,
|
||||
},
|
||||
})
|
||||
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.content == "hello from dict"
|
||||
assert result.usage["total_tokens"] == 3
|
||||
|
||||
|
||||
def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None:
|
||||
result = OpenAICompatProvider._parse_chunks(["hello ", "world"])
|
||||
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.content == "hello world"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user