fix(providers): enforce role alternation for non-Claude providers

Some LLM providers (OpenAI-compat, Azure, vLLM, Ollama) reject requests
with consecutive same-role messages or trailing assistant messages. Add
_enforce_role_alternation() to merge consecutive same-role user/assistant
messages and strip trailing assistant messages before sending to the API.
This commit is contained in:
Ziyan Lin 2026-03-30 15:15:15 +08:00
parent c8c520cc9a
commit 26ae906116
4 changed files with 170 additions and 4 deletions

View File

@ -94,9 +94,11 @@ class AzureOpenAIProvider(LLMProvider):
) -> dict[str, Any]:
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
payload: dict[str, Any] = {
"messages": self._sanitize_request_messages(
self._sanitize_empty_content(messages),
_AZURE_MSG_KEYS,
"messages": self._enforce_role_alternation(
self._sanitize_request_messages(
self._sanitize_empty_content(messages),
_AZURE_MSG_KEYS,
)
),
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
}

View File

@ -196,6 +196,42 @@ class LLMProvider(ABC):
err = (content or "").lower()
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
@staticmethod
def _enforce_role_alternation(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Merge consecutive same-role messages and drop trailing assistant messages.

    Some providers (OpenAI-compat, Azure, vLLM, Ollama, etc.) reject requests
    where the last message is 'assistant' (prefill not supported) or two
    consecutive non-system messages share the same role.

    Only adjacent ``user``/``user`` and ``assistant``/``assistant`` pairs are
    merged; ``system``, ``tool`` and every other role pass through unchanged.
    When both contents are strings (``None`` counts as empty) they are joined
    with a blank line and stripped. Otherwise the later message replaces the
    earlier one. In both cases the ``tool_calls`` of the two messages are
    concatenated onto the surviving message so that no tool call is lost.

    Args:
        messages: Chat messages in request order. Neither the list nor the
            message dicts inside it are mutated.

    Returns:
        A new list of shallow-copied messages. A falsy ``messages`` value is
        returned as-is.
    """
    if not messages:
        return messages

    merged: list[dict[str, Any]] = []
    for msg in messages:
        role = msg.get("role")
        mergeable = (
            bool(merged)
            and role in ("user", "assistant")
            and merged[-1].get("role") == role
        )
        if not mergeable:
            # Copy so that later in-place merging never touches the caller's dict.
            merged.append(dict(msg))
            continue

        prev = merged[-1]
        prev_content = prev.get("content") or ""
        curr_content = msg.get("content") or ""
        # Gather tool calls from both sides before either message is altered
        # or replaced; building a new list keeps the caller's lists untouched.
        tool_calls = [
            *(prev.get("tool_calls") or []),
            *(msg.get("tool_calls") or []),
        ]

        if isinstance(prev_content, str) and isinstance(curr_content, str):
            prev["content"] = (prev_content + "\n\n" + curr_content).strip()
        else:
            # Structured (non-string) content is not joined as text; the
            # later message wins.
            prev = merged[-1] = dict(msg)

        if tool_calls:
            prev["tool_calls"] = tool_calls

    # A trailing assistant message would be treated as prefill; remove it.
    while merged and merged[-1].get("role") == "assistant":
        merged.pop()
    return merged
@staticmethod
def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
"""Replace image_url blocks with text placeholder. Returns None if no images found."""

View File

@ -215,7 +215,7 @@ class OpenAICompatProvider(LLMProvider):
clean["tool_calls"] = normalized
if "tool_call_id" in clean and clean["tool_call_id"]:
clean["tool_call_id"] = map_id(clean["tool_call_id"])
return sanitized
return self._enforce_role_alternation(sanitized)
# ------------------------------------------------------------------
# Build kwargs

View File

@ -0,0 +1,128 @@
"""Tests for LLMProvider._enforce_role_alternation."""
from nanobot.providers.base import LLMProvider
class TestEnforceRoleAlternation:
    """Check that adjacent same-role turns are merged and trailing assistant turns are dropped."""

    def test_empty_messages(self):
        """An empty conversation is returned empty."""
        out = LLMProvider._enforce_role_alternation([])
        assert out == []

    def test_no_change_needed(self):
        """A conversation that already alternates and ends on a user turn keeps its length."""
        conversation = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "user", "content": "Bye"},
        ]
        out = LLMProvider._enforce_role_alternation(conversation)
        assert len(out) == 4
        last = out[-1]
        assert last["role"] == "user"

    def test_trailing_assistant_removed(self):
        """A single assistant turn at the end is dropped."""
        conversation = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
        ]
        out = LLMProvider._enforce_role_alternation(conversation)
        assert len(out) == 1
        first = out[0]
        assert first["role"] == "user"

    def test_multiple_trailing_assistants_removed(self):
        """Several assistant turns at the end are all dropped."""
        conversation = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "A"},
            {"role": "assistant", "content": "B"},
        ]
        out = LLMProvider._enforce_role_alternation(conversation)
        assert len(out) == 1
        first = out[0]
        assert first["role"] == "user"

    def test_consecutive_user_messages_merged(self):
        """Two adjacent user turns collapse into one that contains both texts."""
        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "user", "content": "How are you?"},
        ]
        out = LLMProvider._enforce_role_alternation(conversation)
        assert len(out) == 1
        text = out[0]["content"]
        assert "Hello" in text
        assert "How are you?" in text

    def test_consecutive_assistant_messages_merged(self):
        """Two adjacent assistant turns in mid-conversation collapse into one."""
        conversation = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "assistant", "content": "How can I help?"},
            {"role": "user", "content": "Thanks"},
        ]
        out = LLMProvider._enforce_role_alternation(conversation)
        assert len(out) == 3
        text = out[1]["content"]
        assert "Hello!" in text
        assert "How can I help?" in text

    def test_system_messages_not_merged(self):
        """Adjacent system messages stay separate and keep their content."""
        conversation = [
            {"role": "system", "content": "System A"},
            {"role": "system", "content": "System B"},
            {"role": "user", "content": "Hi"},
        ]
        out = LLMProvider._enforce_role_alternation(conversation)
        assert len(out) == 3
        first, second = out[0], out[1]
        assert first["content"] == "System A"
        assert second["content"] == "System B"

    def test_tool_messages_not_merged(self):
        """Adjacent tool results are each kept."""
        conversation = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
            {"role": "tool", "content": "result1", "tool_call_id": "1"},
            {"role": "tool", "content": "result2", "tool_call_id": "2"},
            {"role": "user", "content": "Next"},
        ]
        out = LLMProvider._enforce_role_alternation(conversation)
        tool_count = sum(1 for entry in out if entry["role"] == "tool")
        assert tool_count == 2

    def test_non_string_content_uses_latest(self):
        """When one side has non-string content, the later message replaces the earlier."""
        conversation = [
            {"role": "user", "content": [{"type": "text", "text": "A"}]},
            {"role": "user", "content": "B"},
        ]
        out = LLMProvider._enforce_role_alternation(conversation)
        assert len(out) == 1
        survivor = out[0]
        assert survivor["content"] == "B"

    def test_original_messages_not_mutated(self):
        """Merging leaves the caller's list and its first dict unchanged."""
        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "user", "content": "World"},
        ]
        snapshot = dict(conversation[0])
        LLMProvider._enforce_role_alternation(conversation)
        assert conversation[0] == snapshot
        assert len(conversation) == 2

    def test_only_assistant_messages(self):
        """A conversation made only of assistant turns reduces to nothing."""
        conversation = [
            {"role": "assistant", "content": "A"},
            {"role": "assistant", "content": "B"},
        ]
        out = LLMProvider._enforce_role_alternation(conversation)
        assert out == []

    def test_realistic_conversation(self):
        """Merging and trailing-assistant removal both apply in one pass."""
        conversation = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "4"},
            {"role": "user", "content": "And 3+3?"},
            {"role": "user", "content": "(please be quick)"},
            {"role": "assistant", "content": "6"},
        ]
        out = LLMProvider._enforce_role_alternation(conversation)
        assert len(out) == 4
        assert out[2]["role"] == "assistant"
        final_turn = out[3]
        assert final_turn["role"] == "user"
        assert "And 3+3?" in final_turn["content"]
        assert "(please be quick)" in final_turn["content"]