diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index d5833c9ae..275b1ea08 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -353,6 +353,42 @@ class LLMProvider(ABC): # Unknown 429 defaults to WAIT+retry. return True + @staticmethod + def _enforce_role_alternation(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Merge consecutive same-role messages and drop trailing assistant messages. + + Some providers (OpenAI-compat, Azure, vLLM, Ollama, etc.) reject requests + where the last message is 'assistant' (prefill not supported) or two + consecutive non-system messages share the same role. + """ + if not messages: + return messages + + merged: list[dict[str, Any]] = [] + for msg in messages: + role = msg.get("role") + if ( + merged + and role != "system" + and role not in ("tool",) + and merged[-1].get("role") == role + and role in ("user", "assistant") + ): + prev = merged[-1] + prev_content = prev.get("content") or "" + curr_content = msg.get("content") or "" + if isinstance(prev_content, str) and isinstance(curr_content, str): + prev["content"] = (prev_content + "\n\n" + curr_content).strip() + else: + merged[-1] = dict(msg) + else: + merged.append(dict(msg)) + + while merged and merged[-1].get("role") == "assistant": + merged.pop() + + return merged + @staticmethod def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None: """Replace image_url blocks with text placeholder. 
def test_stale_extra_content_in_tool_calls_survives_sanitize() -> None:
    """Gemini's extra_content blob on a tool call must survive sanitization."""
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider()

    # Realistic exchange: the assistant's tool call carries the stale
    # extra_content payload; alternation enforcement must leave it in place.
    stale_call = {
        "id": "call_1",
        "type": "function",
        "function": {"name": "fn", "arguments": "{}"},
        "extra_content": GEMINI_EXTRA,
    }
    history = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": None, "tool_calls": [stale_call]},
        {"role": "tool", "content": "ok", "tool_call_id": "call_1"},
        {"role": "user", "content": "thanks"},
    ]

    result = provider._sanitize_messages(history)
    assert result[1]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA
"""Tests for LLMProvider._enforce_role_alternation."""

from nanobot.providers.base import LLMProvider


class TestEnforceRoleAlternation:
    """Exercise same-role merging and trailing-assistant stripping."""

    def test_empty_messages(self):
        """An empty history passes through untouched."""
        assert LLMProvider._enforce_role_alternation([]) == []

    def test_no_change_needed(self):
        """Already-alternating conversations are preserved in full."""
        history = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "user", "content": "Bye"},
        ]
        out = LLMProvider._enforce_role_alternation(history)
        assert len(out) == 4
        assert out[-1]["role"] == "user"

    def test_trailing_assistant_removed(self):
        """A single trailing assistant message (prefill) is dropped."""
        history = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
        ]
        out = LLMProvider._enforce_role_alternation(history)
        assert len(out) == 1
        assert out[0]["role"] == "user"

    def test_multiple_trailing_assistants_removed(self):
        """Every trailing assistant message is stripped, not just the last."""
        history = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "A"},
            {"role": "assistant", "content": "B"},
        ]
        out = LLMProvider._enforce_role_alternation(history)
        assert len(out) == 1
        assert out[0]["role"] == "user"

    def test_consecutive_user_messages_merged(self):
        """Back-to-back user messages collapse into one combined message."""
        history = [
            {"role": "user", "content": "Hello"},
            {"role": "user", "content": "How are you?"},
        ]
        out = LLMProvider._enforce_role_alternation(history)
        assert len(out) == 1
        assert "Hello" in out[0]["content"]
        assert "How are you?" in out[0]["content"]

    def test_consecutive_assistant_messages_merged(self):
        """Back-to-back assistant messages mid-conversation are combined."""
        history = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "assistant", "content": "How can I help?"},
            {"role": "user", "content": "Thanks"},
        ]
        out = LLMProvider._enforce_role_alternation(history)
        assert len(out) == 3
        assert "Hello!" in out[1]["content"]
        assert "How can I help?" in out[1]["content"]

    def test_system_messages_not_merged(self):
        """Adjacent system messages stay separate."""
        history = [
            {"role": "system", "content": "System A"},
            {"role": "system", "content": "System B"},
            {"role": "user", "content": "Hi"},
        ]
        out = LLMProvider._enforce_role_alternation(history)
        assert len(out) == 3
        assert out[0]["content"] == "System A"
        assert out[1]["content"] == "System B"

    def test_tool_messages_not_merged(self):
        """Consecutive tool results must never be collapsed."""
        history = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
            {"role": "tool", "content": "result1", "tool_call_id": "1"},
            {"role": "tool", "content": "result2", "tool_call_id": "2"},
            {"role": "user", "content": "Next"},
        ]
        out = LLMProvider._enforce_role_alternation(history)
        tool_results = [m for m in out if m["role"] == "tool"]
        assert len(tool_results) == 2

    def test_non_string_content_uses_latest(self):
        """Structured content can't be joined textually; the newest wins."""
        history = [
            {"role": "user", "content": [{"type": "text", "text": "A"}]},
            {"role": "user", "content": "B"},
        ]
        out = LLMProvider._enforce_role_alternation(history)
        assert len(out) == 1
        assert out[0]["content"] == "B"

    def test_original_messages_not_mutated(self):
        """The caller's message list and dicts are left untouched."""
        history = [
            {"role": "user", "content": "Hello"},
            {"role": "user", "content": "World"},
        ]
        snapshot = dict(history[0])
        LLMProvider._enforce_role_alternation(history)
        assert history[0] == snapshot
        assert len(history) == 2

    def test_only_assistant_messages(self):
        """A history of nothing but assistant turns reduces to empty."""
        history = [
            {"role": "assistant", "content": "A"},
            {"role": "assistant", "content": "B"},
        ]
        assert LLMProvider._enforce_role_alternation(history) == []

    def test_realistic_conversation(self):
        """End-to-end: system kept, doubled user turn merged, tail intact."""
        history = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "4"},
            {"role": "user", "content": "And 3+3?"},
            {"role": "user", "content": "(please be quick)"},
            {"role": "assistant", "content": "6"},
        ]
        out = LLMProvider._enforce_role_alternation(history)
        assert len(out) == 4
        assert out[2]["role"] == "assistant"
        assert out[3]["role"] == "user"
        assert "And 3+3?" in out[3]["content"]
        assert "(please be quick)" in out[3]["content"]