mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-11 21:53:37 +00:00
Some LLM providers (OpenAI-compat, Azure, vLLM, Ollama) reject requests with consecutive same-role messages or trailing assistant messages. Add _enforce_role_alternation() to merge consecutive same-role user/assistant messages and strip trailing assistant messages before sending to the API.
129 lines
4.8 KiB
Python
129 lines
4.8 KiB
Python
"""Tests for LLMProvider._enforce_role_alternation."""
|
|
|
|
from nanobot.providers.base import LLMProvider
|
|
|
|
|
|
class TestEnforceRoleAlternation:
|
|
"""Verify trailing-assistant removal and consecutive same-role merging."""
|
|
|
|
def test_empty_messages(self):
|
|
assert LLMProvider._enforce_role_alternation([]) == []
|
|
|
|
def test_no_change_needed(self):
|
|
msgs = [
|
|
{"role": "system", "content": "You are helpful."},
|
|
{"role": "user", "content": "Hi"},
|
|
{"role": "assistant", "content": "Hello!"},
|
|
{"role": "user", "content": "Bye"},
|
|
]
|
|
result = LLMProvider._enforce_role_alternation(msgs)
|
|
assert len(result) == 4
|
|
assert result[-1]["role"] == "user"
|
|
|
|
def test_trailing_assistant_removed(self):
|
|
msgs = [
|
|
{"role": "user", "content": "Hi"},
|
|
{"role": "assistant", "content": "Hello!"},
|
|
]
|
|
result = LLMProvider._enforce_role_alternation(msgs)
|
|
assert len(result) == 1
|
|
assert result[0]["role"] == "user"
|
|
|
|
def test_multiple_trailing_assistants_removed(self):
|
|
msgs = [
|
|
{"role": "user", "content": "Hi"},
|
|
{"role": "assistant", "content": "A"},
|
|
{"role": "assistant", "content": "B"},
|
|
]
|
|
result = LLMProvider._enforce_role_alternation(msgs)
|
|
assert len(result) == 1
|
|
assert result[0]["role"] == "user"
|
|
|
|
def test_consecutive_user_messages_merged(self):
|
|
msgs = [
|
|
{"role": "user", "content": "Hello"},
|
|
{"role": "user", "content": "How are you?"},
|
|
]
|
|
result = LLMProvider._enforce_role_alternation(msgs)
|
|
assert len(result) == 1
|
|
assert "Hello" in result[0]["content"]
|
|
assert "How are you?" in result[0]["content"]
|
|
|
|
def test_consecutive_assistant_messages_merged(self):
|
|
msgs = [
|
|
{"role": "user", "content": "Hi"},
|
|
{"role": "assistant", "content": "Hello!"},
|
|
{"role": "assistant", "content": "How can I help?"},
|
|
{"role": "user", "content": "Thanks"},
|
|
]
|
|
result = LLMProvider._enforce_role_alternation(msgs)
|
|
assert len(result) == 3
|
|
assert "Hello!" in result[1]["content"]
|
|
assert "How can I help?" in result[1]["content"]
|
|
|
|
def test_system_messages_not_merged(self):
|
|
msgs = [
|
|
{"role": "system", "content": "System A"},
|
|
{"role": "system", "content": "System B"},
|
|
{"role": "user", "content": "Hi"},
|
|
]
|
|
result = LLMProvider._enforce_role_alternation(msgs)
|
|
assert len(result) == 3
|
|
assert result[0]["content"] == "System A"
|
|
assert result[1]["content"] == "System B"
|
|
|
|
def test_tool_messages_not_merged(self):
|
|
msgs = [
|
|
{"role": "user", "content": "Hi"},
|
|
{"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
|
|
{"role": "tool", "content": "result1", "tool_call_id": "1"},
|
|
{"role": "tool", "content": "result2", "tool_call_id": "2"},
|
|
{"role": "user", "content": "Next"},
|
|
]
|
|
result = LLMProvider._enforce_role_alternation(msgs)
|
|
tool_msgs = [m for m in result if m["role"] == "tool"]
|
|
assert len(tool_msgs) == 2
|
|
|
|
def test_non_string_content_uses_latest(self):
|
|
msgs = [
|
|
{"role": "user", "content": [{"type": "text", "text": "A"}]},
|
|
{"role": "user", "content": "B"},
|
|
]
|
|
result = LLMProvider._enforce_role_alternation(msgs)
|
|
assert len(result) == 1
|
|
assert result[0]["content"] == "B"
|
|
|
|
def test_original_messages_not_mutated(self):
|
|
msgs = [
|
|
{"role": "user", "content": "Hello"},
|
|
{"role": "user", "content": "World"},
|
|
]
|
|
original_first = dict(msgs[0])
|
|
LLMProvider._enforce_role_alternation(msgs)
|
|
assert msgs[0] == original_first
|
|
assert len(msgs) == 2
|
|
|
|
def test_only_assistant_messages(self):
|
|
msgs = [
|
|
{"role": "assistant", "content": "A"},
|
|
{"role": "assistant", "content": "B"},
|
|
]
|
|
result = LLMProvider._enforce_role_alternation(msgs)
|
|
assert result == []
|
|
|
|
def test_realistic_conversation(self):
|
|
msgs = [
|
|
{"role": "system", "content": "You are helpful."},
|
|
{"role": "user", "content": "What is 2+2?"},
|
|
{"role": "assistant", "content": "4"},
|
|
{"role": "user", "content": "And 3+3?"},
|
|
{"role": "user", "content": "(please be quick)"},
|
|
{"role": "assistant", "content": "6"},
|
|
]
|
|
result = LLMProvider._enforce_role_alternation(msgs)
|
|
assert len(result) == 4
|
|
assert result[2]["role"] == "assistant"
|
|
assert result[3]["role"] == "user"
|
|
assert "And 3+3?" in result[3]["content"]
|
|
assert "(please be quick)" in result[3]["content"]
|