mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-10 21:23:39 +00:00
Merge PR #2637: fix(providers): enforce role alternation for non-Claude providers
fix(providers): enforce role alternation for non-Claude providers
This commit is contained in:
commit
3361ac9dd1
@ -353,6 +353,42 @@ class LLMProvider(ABC):
|
||||
# Unknown 429 defaults to WAIT+retry.
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _enforce_role_alternation(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Merge consecutive same-role messages and drop trailing assistant messages.
|
||||
|
||||
Some providers (OpenAI-compat, Azure, vLLM, Ollama, etc.) reject requests
|
||||
where the last message is 'assistant' (prefill not supported) or two
|
||||
consecutive non-system messages share the same role.
|
||||
"""
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
merged: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
if (
|
||||
merged
|
||||
and role != "system"
|
||||
and role not in ("tool",)
|
||||
and merged[-1].get("role") == role
|
||||
and role in ("user", "assistant")
|
||||
):
|
||||
prev = merged[-1]
|
||||
prev_content = prev.get("content") or ""
|
||||
curr_content = msg.get("content") or ""
|
||||
if isinstance(prev_content, str) and isinstance(curr_content, str):
|
||||
prev["content"] = (prev_content + "\n\n" + curr_content).strip()
|
||||
else:
|
||||
merged[-1] = dict(msg)
|
||||
else:
|
||||
merged.append(dict(msg))
|
||||
|
||||
while merged and merged[-1].get("role") == "assistant":
|
||||
merged.pop()
|
||||
|
||||
return merged
|
||||
|
||||
@staticmethod
|
||||
def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
|
||||
"""Replace image_url blocks with text placeholder. Returns None if no images found."""
|
||||
|
||||
@ -245,7 +245,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
clean["tool_calls"] = normalized
|
||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||
return sanitized
|
||||
return self._enforce_role_alternation(sanitized)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Build kwargs
|
||||
|
||||
@ -184,17 +184,22 @@ def test_stale_extra_content_in_tool_calls_survives_sanitize() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
messages = [{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "fn", "arguments": "{}"},
|
||||
"extra_content": GEMINI_EXTRA,
|
||||
}],
|
||||
}]
|
||||
messages = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "fn", "arguments": "{}"},
|
||||
"extra_content": GEMINI_EXTRA,
|
||||
}],
|
||||
},
|
||||
{"role": "tool", "content": "ok", "tool_call_id": "call_1"},
|
||||
{"role": "user", "content": "thanks"},
|
||||
]
|
||||
|
||||
sanitized = provider._sanitize_messages(messages)
|
||||
|
||||
assert sanitized[0]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA
|
||||
assert sanitized[1]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA
|
||||
|
||||
128
tests/providers/test_enforce_role_alternation.py
Normal file
128
tests/providers/test_enforce_role_alternation.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""Tests for LLMProvider._enforce_role_alternation."""
|
||||
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
class TestEnforceRoleAlternation:
    """Exercise same-role merging and trailing-assistant trimming."""

    @staticmethod
    def _msg(role, content, **extra):
        """Build one chat message dict; ``extra`` adds keys such as ``tool_call_id``."""
        return {"role": role, "content": content, **extra}

    @staticmethod
    def _fix(conversation):
        """Run the helper under test on ``conversation`` and return its result."""
        return LLMProvider._enforce_role_alternation(conversation)

    def test_empty_messages(self):
        # An empty conversation stays empty.
        assert self._fix([]) == []

    def test_no_change_needed(self):
        # An already alternating conversation ending on a user turn keeps its length.
        conversation = [
            self._msg("system", "You are helpful."),
            self._msg("user", "Hi"),
            self._msg("assistant", "Hello!"),
            self._msg("user", "Bye"),
        ]
        fixed = self._fix(conversation)
        assert len(fixed) == 4
        assert fixed[-1]["role"] == "user"

    def test_trailing_assistant_removed(self):
        # A single trailing assistant turn is dropped.
        fixed = self._fix([
            self._msg("user", "Hi"),
            self._msg("assistant", "Hello!"),
        ])
        assert len(fixed) == 1
        assert fixed[0]["role"] == "user"

    def test_multiple_trailing_assistants_removed(self):
        # Several trailing assistant turns are all dropped.
        fixed = self._fix([
            self._msg("user", "Hi"),
            self._msg("assistant", "A"),
            self._msg("assistant", "B"),
        ])
        assert len(fixed) == 1
        assert fixed[0]["role"] == "user"

    def test_consecutive_user_messages_merged(self):
        # Two adjacent user turns collapse into one that carries both texts.
        fixed = self._fix([
            self._msg("user", "Hello"),
            self._msg("user", "How are you?"),
        ])
        assert len(fixed) == 1
        combined = fixed[0]["content"]
        assert "Hello" in combined
        assert "How are you?" in combined

    def test_consecutive_assistant_messages_merged(self):
        # Adjacent assistant turns in the middle of a conversation are merged.
        fixed = self._fix([
            self._msg("user", "Hi"),
            self._msg("assistant", "Hello!"),
            self._msg("assistant", "How can I help?"),
            self._msg("user", "Thanks"),
        ])
        assert len(fixed) == 3
        combined = fixed[1]["content"]
        assert "Hello!" in combined
        assert "How can I help?" in combined

    def test_system_messages_not_merged(self):
        # System messages are never merged, even when adjacent.
        fixed = self._fix([
            self._msg("system", "System A"),
            self._msg("system", "System B"),
            self._msg("user", "Hi"),
        ])
        assert len(fixed) == 3
        assert fixed[0]["content"] == "System A"
        assert fixed[1]["content"] == "System B"

    def test_tool_messages_not_merged(self):
        # Adjacent tool results stay separate.
        fixed = self._fix([
            self._msg("user", "Hi"),
            self._msg("assistant", None, tool_calls=[{"id": "1"}]),
            self._msg("tool", "result1", tool_call_id="1"),
            self._msg("tool", "result2", tool_call_id="2"),
            self._msg("user", "Next"),
        ])
        tool_count = sum(1 for entry in fixed if entry["role"] == "tool")
        assert tool_count == 2

    def test_non_string_content_uses_latest(self):
        # When content is not plain text on both sides, the later message wins.
        fixed = self._fix([
            self._msg("user", [{"type": "text", "text": "A"}]),
            self._msg("user", "B"),
        ])
        assert len(fixed) == 1
        assert fixed[0]["content"] == "B"

    def test_original_messages_not_mutated(self):
        # The caller's list and its first message are left untouched.
        conversation = [
            self._msg("user", "Hello"),
            self._msg("user", "World"),
        ]
        snapshot = {**conversation[0]}
        self._fix(conversation)
        assert conversation[0] == snapshot
        assert len(conversation) == 2

    def test_only_assistant_messages(self):
        # A conversation made only of assistant turns reduces to nothing.
        fixed = self._fix([
            self._msg("assistant", "A"),
            self._msg("assistant", "B"),
        ])
        assert fixed == []

    def test_realistic_conversation(self):
        # Merging and trailing-assistant trimming applied together.
        fixed = self._fix([
            self._msg("system", "You are helpful."),
            self._msg("user", "What is 2+2?"),
            self._msg("assistant", "4"),
            self._msg("user", "And 3+3?"),
            self._msg("user", "(please be quick)"),
            self._msg("assistant", "6"),
        ])
        assert len(fixed) == 4
        assert fixed[2]["role"] == "assistant"
        assert fixed[3]["role"] == "user"
        last_text = fixed[3]["content"]
        assert "And 3+3?" in last_text
        assert "(please be quick)" in last_text
|
||||
@ -532,6 +532,7 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
sanitized = provider._sanitize_messages([
|
||||
{"role": "user", "content": "hi"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "done",
|
||||
@ -545,12 +546,13 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
|
||||
"extra_content": {"google": {"thought_signature": "sig"}},
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
{"role": "user", "content": "thanks"},
|
||||
])
|
||||
|
||||
assert sanitized[0]["reasoning_content"] == "hidden"
|
||||
assert sanitized[0]["extra_content"] == {"debug": True}
|
||||
assert sanitized[0]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
||||
assert sanitized[1]["reasoning_content"] == "hidden"
|
||||
assert sanitized[1]["extra_content"] == {"debug": True}
|
||||
assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user