diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index a210bf72d..a69a716b1 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -193,7 +193,126 @@ class OpenAICompatProvider(LLMProvider): # Response parsing # ------------------------------------------------------------------ + @staticmethod + def _maybe_mapping(value: Any) -> dict[str, Any] | None: + if isinstance(value, dict): + return value + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + dumped = model_dump() + if isinstance(dumped, dict): + return dumped + return None + + @classmethod + def _extract_text_content(cls, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, list): + parts: list[str] = [] + for item in value: + item_map = cls._maybe_mapping(item) + if item_map: + text = item_map.get("text") + if isinstance(text, str): + parts.append(text) + continue + text = getattr(item, "text", None) + if isinstance(text, str): + parts.append(text) + continue + if isinstance(item, str): + parts.append(item) + return "".join(parts) or None + return str(value) + + @classmethod + def _extract_usage(cls, response: Any) -> dict[str, int]: + usage_obj = None + response_map = cls._maybe_mapping(response) + if response_map is not None: + usage_obj = response_map.get("usage") + elif hasattr(response, "usage") and response.usage: + usage_obj = response.usage + + usage_map = cls._maybe_mapping(usage_obj) + if usage_map is not None: + return { + "prompt_tokens": int(usage_map.get("prompt_tokens") or 0), + "completion_tokens": int(usage_map.get("completion_tokens") or 0), + "total_tokens": int(usage_map.get("total_tokens") or 0), + } + + if usage_obj: + return { + "prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0, + "completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0, + "total_tokens": getattr(usage_obj, "total_tokens", 0) or 0, + } + return {} + def _parse(self, response: Any) -> LLMResponse: + if isinstance(response, str): + return LLMResponse(content=response, finish_reason="stop") + + response_map = self._maybe_mapping(response) + if response_map is not None: + choices = response_map.get("choices") or [] + if not choices: + content = self._extract_text_content( + response_map.get("content") or response_map.get("output_text") + ) + if content is not None: + return LLMResponse( + content=content, + finish_reason=str(response_map.get("finish_reason") or "stop"), + usage=self._extract_usage(response_map), + ) + return LLMResponse(content="Error: API returned empty choices.", finish_reason="error") + + choice0 = self._maybe_mapping(choices[0]) or {} + msg0 = self._maybe_mapping(choice0.get("message")) or {} + content = self._extract_text_content(msg0.get("content")) + finish_reason = str(choice0.get("finish_reason") or "stop") + + raw_tool_calls: list[Any] = [] + reasoning_content = msg0.get("reasoning_content") + for ch in choices: + ch_map = self._maybe_mapping(ch) or {} + m = self._maybe_mapping(ch_map.get("message")) or {} + tool_calls = m.get("tool_calls") + if isinstance(tool_calls, list) and tool_calls: + raw_tool_calls.extend(tool_calls) + if ch_map.get("finish_reason") in ("tool_calls", "stop"): + finish_reason = str(ch_map["finish_reason"]) + if not content: + content = self._extract_text_content(m.get("content")) + if not reasoning_content: + reasoning_content = m.get("reasoning_content") + + parsed_tool_calls = [] + for tc in raw_tool_calls: + tc_map = self._maybe_mapping(tc) or {} + fn = self._maybe_mapping(tc_map.get("function")) or {} + args = fn.get("arguments", {}) + if isinstance(args, str): + args = json_repair.loads(args) + parsed_tool_calls.append(ToolCallRequest( + id=_short_tool_id(), + name=str(fn.get("name") or ""), + arguments=args if isinstance(args, dict) else {}, + )) + + return LLMResponse( + content=content, + tool_calls=parsed_tool_calls, + finish_reason=finish_reason, + usage=self._extract_usage(response_map), + reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None, + ) + if not response.choices: return LLMResponse(content="Error: API returned empty choices.", finish_reason="error") @@ -223,39 +342,60 @@ class OpenAICompatProvider(LLMProvider): arguments=args, )) - usage: dict[str, int] = {} - if hasattr(response, "usage") and response.usage: - u = response.usage - usage = { - "prompt_tokens": u.prompt_tokens or 0, - "completion_tokens": u.completion_tokens or 0, - "total_tokens": u.total_tokens or 0, - } - return LLMResponse( content=content, tool_calls=tool_calls, finish_reason=finish_reason or "stop", - usage=usage, + usage=self._extract_usage(response), reasoning_content=getattr(msg, "reasoning_content", None) or None, ) - @staticmethod - def _parse_chunks(chunks: list[Any]) -> LLMResponse: + @classmethod + def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse: content_parts: list[str] = [] tc_bufs: dict[int, dict[str, str]] = {} finish_reason = "stop" usage: dict[str, int] = {} for chunk in chunks: + if isinstance(chunk, str): + content_parts.append(chunk) + continue + + chunk_map = cls._maybe_mapping(chunk) + if chunk_map is not None: + choices = chunk_map.get("choices") or [] + if not choices: + usage = cls._extract_usage(chunk_map) or usage + text = cls._extract_text_content( + chunk_map.get("content") or chunk_map.get("output_text") + ) + if text: + content_parts.append(text) + continue + choice = cls._maybe_mapping(choices[0]) or {} + if choice.get("finish_reason"): + finish_reason = str(choice["finish_reason"]) + delta = cls._maybe_mapping(choice.get("delta")) or {} + text = cls._extract_text_content(delta.get("content")) + if text: + content_parts.append(text) + for idx, tc in enumerate(delta.get("tool_calls") or []): + tc_map = cls._maybe_mapping(tc) or {} + tc_index = tc_map.get("index", idx) + buf = tc_bufs.setdefault(tc_index, {"id": "", "name": "", "arguments": ""}) + if tc_map.get("id"): + buf["id"] = str(tc_map["id"]) + fn = cls._maybe_mapping(tc_map.get("function")) or {} + if fn.get("name"): + buf["name"] = str(fn["name"]) + if fn.get("arguments"): + buf["arguments"] += str(fn["arguments"]) + usage = cls._extract_usage(chunk_map) or usage + continue + if not chunk.choices: - if hasattr(chunk, "usage") and chunk.usage: - u = chunk.usage - usage = { - "prompt_tokens": u.prompt_tokens or 0, - "completion_tokens": u.completion_tokens or 0, - "total_tokens": u.total_tokens or 0, - } + usage = cls._extract_usage(chunk) or usage continue choice = chunk.choices[0] if choice.finish_reason: diff --git a/tests/providers/test_custom_provider.py b/tests/providers/test_custom_provider.py index bb46b887a..d2a9f4247 100644 --- a/tests/providers/test_custom_provider.py +++ b/tests/providers/test_custom_provider.py @@ -15,3 +15,41 @@ def test_custom_provider_parse_handles_empty_choices() -> None: assert result.finish_reason == "error" assert "empty choices" in result.content + + +def test_custom_provider_parse_accepts_plain_string_response() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + result = provider._parse("hello from backend") + + assert result.finish_reason == "stop" + assert result.content == "hello from backend" + + +def test_custom_provider_parse_accepts_dict_response() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + result = provider._parse({ + "choices": [{ + "message": {"content": "hello from dict"}, + "finish_reason": "stop", + }], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3, + }, + }) + + assert result.finish_reason == "stop" + assert result.content == "hello from dict" + assert result.usage["total_tokens"] == 3 + + +def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None: + result = OpenAICompatProvider._parse_chunks(["hello ", "world"]) + + assert result.finish_reason == "stop" + assert result.content == "hello world"