fix(multimodal): image OOM guard, Feishu post media extraction, vision fallback

- Add file size pre-check via stat() before read_bytes() to prevent OOM
  on oversized images/audio/video
- Fix _extract_post_content to extract media tags (file_key) from Feishu
  post messages so videos are no longer silently dropped
- Add supports_vision=False guard to downgrade images to text placeholders
- Add video_mime_compat() for video format validation
- Use full file path in content_text so model read_file works if needed
- Pass input_limits to AgentLoop in nanobot.py facade
- Deduplicate _MEDIA_PLACEHOLDER_TYPES from LLMProvider constant
- Remove unused _extract_post_text legacy wrapper
- Add 14 new tests covering vision fallback, count limits, video compat
This commit is contained in:
chengyongru 2026-04-09 01:13:40 +08:00
parent b9346b0d59
commit c053f9eba8
8 changed files with 298 additions and 44 deletions

View File

@ -2,6 +2,7 @@
import base64
import mimetypes
import os
import platform
from pathlib import Path
from typing import Any
@ -16,6 +17,7 @@ from nanobot.utils.helpers import (
current_time_str,
detect_audio_mime,
detect_image_mime,
video_mime_compat,
)
from nanobot.utils.prompt_templates import render_template
@ -33,6 +35,18 @@ class ContextBuilder:
self.skills = SkillsLoader(workspace)
self.input_limits = input_limits or InputLimitsConfig()
@staticmethod
def _file_size_ok(p: Path, max_bytes: int) -> bool | None:
"""Check file size via stat without reading into memory.
Returns True if size is within limit, False if oversized,
None if file cannot be stat'd (caller should try read_bytes instead).
"""
try:
return os.stat(p).st_size <= max_bytes
except OSError:
return None
def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
"""Build the system prompt from identity, bootstrap files, memory, and skills."""
parts = [self._get_identity()]
@ -171,7 +185,8 @@ class ContextBuilder:
text: The user text message.
media: List of file paths to media files.
supports_vision: True=model supports images, False=use placeholder,
None=unconfigured (send images as before).
None=unconfigured (send images as before, let
provider/retry handle degradation).
supports_audio: True=model supports native audio, False/None=skip
(channel layer already transcribed).
supports_video: True=model supports native video, False/None=use
@ -210,15 +225,22 @@ class ContextBuilder:
# Process images
for path in image_media:
p = Path(path)
# When explicitly marked as non-vision, downgrade to text placeholder
if supports_vision is False:
blocks.append({"type": "text", "text": f"[image: {p}]"})
continue
size_ok = self._file_size_ok(p, limits.max_input_image_bytes)
if size_ok is False:
size_mb = limits.max_input_image_bytes // (1024 * 1024)
notes.append(f"[Skipped image: file too large ({p.name}, limit {size_mb} MB)]")
continue
try:
raw = p.read_bytes()
except OSError:
notes.append(f"[Skipped image: unable to read ({p.name or path})]")
continue
if len(raw) > limits.max_input_image_bytes:
size_mb = limits.max_input_image_bytes // (1024 * 1024)
notes.append(f"[Skipped image: file too large ({p.name}, limit {size_mb} MB)]")
continue
img_mime = detect_image_mime(raw[:32]) or mimetypes.guess_type(path)[0]
if not img_mime or not img_mime.startswith("image/"):
notes.append(f"[Skipped image: unsupported or invalid image format ({p.name})]")
@ -232,10 +254,24 @@ class ContextBuilder:
p = Path(path)
guessed_mime = mimetypes.guess_type(path)[0] or ""
is_audio = guessed_mime.startswith("audio/")
is_video = guessed_mime.startswith("video/")
# Pre-check file size via stat to avoid reading oversized files into memory.
# Determine the relevant byte limit based on detected media type.
_size_limit = 0
if is_audio or is_video:
_size_limit = limits.max_input_audio_bytes if is_audio else limits.max_input_video_bytes
_stat_size_ok = self._file_size_ok(p, _size_limit) if _size_limit else None
if _stat_size_ok is False:
size_mb = _size_limit // (1024 * 1024)
label = "audio" if is_audio else "video"
notes.append(f"[Skipped {label}: file too large ({p.name}, limit {size_mb} MB)]")
continue
try:
raw = p.read_bytes()
except OSError:
notes.append(f"[Skipped file: unable to read ({p.name or path})]")
continue
# Audio detection: by magic bytes or by filename
@ -264,10 +300,9 @@ class ContextBuilder:
blocks.append({"type": "text", "text": f"[audio: {p}]"})
continue
# Video detection: by filename extension
is_video = guessed_mime.startswith("video/")
# Video detection (already classified above)
if is_video:
if supports_video is True:
if supports_video is True and video_mime_compat(guessed_mime):
video_count += 1
if video_count > limits.max_input_videos:
if video_count == limits.max_input_videos + 1:

View File

@ -632,7 +632,7 @@ class AgentLoop:
metadata=meta,
)
_MEDIA_PLACEHOLDER_TYPES = {"image_url", "input_audio", "video_url"}
_MEDIA_PLACEHOLDER_TYPES = LLMProvider._STRIP_MEDIA_TYPES
def _sanitize_persisted_blocks(
self,

View File

@ -169,19 +169,22 @@ def _extract_element_content(element: dict) -> list[str]:
return parts
def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
"""Extract text and image keys from Feishu post (rich text) message.
def _extract_post_content(content_json: dict) -> tuple[str, list[str], list[dict]]:
"""Extract text and media info from Feishu post (rich text) message.
Handles three payload shapes:
- Direct: {"title": "...", "content": [[...]]}
- Localized: {"zh_cn": {"title": "...", "content": [...]}}
- Wrapped: {"post": {"zh_cn": {"title": "...", "content": [...]}}}
Returns (text, image_keys, media_items) where media_items is a list of
{"tag": "media", "file_key": "..."} dicts for video/file attachments.
"""
def _parse_block(block: dict) -> tuple[str | None, list[str]]:
def _parse_block(block: dict) -> tuple[str | None, list[str], list[dict]]:
if not isinstance(block, dict) or not isinstance(block.get("content"), list):
return None, []
texts, images = [], []
return None, [], []
texts, images, medias = [], [], []
if title := block.get("title"):
texts.append(title)
for row in block["content"]:
@ -201,43 +204,36 @@ def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
texts.append(f"\n```{lang}\n{code_text}\n```\n")
elif tag == "img" and (key := el.get("image_key")):
images.append(key)
return (" ".join(texts).strip() or None), images
elif tag == "media" and el.get("file_key"):
medias.append({"tag": "media", "file_key": el["file_key"]})
return (" ".join(texts).strip() or None), images, medias
# Unwrap optional {"post": ...} envelope
root = content_json
if isinstance(root, dict) and isinstance(root.get("post"), dict):
root = root["post"]
if not isinstance(root, dict):
return "", []
return "", [], []
# Direct format
if "content" in root:
text, imgs = _parse_block(root)
if text or imgs:
return text or "", imgs
text, imgs, medias = _parse_block(root)
if text or imgs or medias:
return text or "", imgs, medias
# Localized: prefer known locales, then fall back to any dict child
for key in ("zh_cn", "en_us", "ja_jp"):
if key in root:
text, imgs = _parse_block(root[key])
if text or imgs:
return text or "", imgs
text, imgs, medias = _parse_block(root[key])
if text or imgs or medias:
return text or "", imgs, medias
for val in root.values():
if isinstance(val, dict):
text, imgs = _parse_block(val)
if text or imgs:
return text or "", imgs
text, imgs, medias = _parse_block(val)
if text or imgs or medias:
return text or "", imgs, medias
return "", []
def _extract_post_text(content_json: dict) -> str:
"""Extract plain text from Feishu post (rich text) message content.
Legacy wrapper for _extract_post_content, returns only text.
"""
text, _ = _extract_post_content(content_json)
return text
return "", [], []
class FeishuConfig(Base):
@ -1027,7 +1023,7 @@ class FeishuChannel(BaseChannel):
file_path = media_dir / filename
file_path.write_bytes(data)
logger.debug("Downloaded {} to {}", msg_type, file_path)
return str(file_path), f"[{msg_type}: {filename}]"
return str(file_path), f"[{msg_type}: {file_path}]"
return None, f"[{msg_type}: download failed]"
@ -1067,7 +1063,7 @@ class FeishuChannel(BaseChannel):
if msg_type == "text":
text = content_json.get("text", "").strip()
elif msg_type == "post":
text, _ = _extract_post_content(content_json)
text, _, _ = _extract_post_content(content_json)
text = text.strip()
else:
text = ""
@ -1542,7 +1538,7 @@ class FeishuChannel(BaseChannel):
content_parts.append(text)
elif msg_type == "post":
text, image_keys = _extract_post_content(content_json)
text, image_keys, media_items = _extract_post_content(content_json)
if text:
content_parts.append(text)
# Download images embedded in post
@ -1553,6 +1549,14 @@ class FeishuChannel(BaseChannel):
if file_path:
media_paths.append(file_path)
content_parts.append(content_text)
# Download media (video/file) embedded in post
for media_item in media_items:
file_path, content_text = await self._download_and_save_media(
"media", media_item, message_id
)
if file_path:
media_paths.append(file_path)
content_parts.append(content_text)
elif msg_type in ("image", "audio", "file", "media"):
file_path, content_text = await self._download_and_save_media(

View File

@ -78,6 +78,7 @@ class Nanobot:
provider_retry_mode=defaults.provider_retry_mode,
web_config=config.tools.web,
exec_config=config.tools.exec,
input_limits=config.tools.input_limits,
restrict_to_workspace=config.tools.restrict_to_workspace,
mcp_servers=config.tools.mcp_servers,
timezone=defaults.timezone,

View File

@ -369,7 +369,7 @@ class LLMProvider(ABC):
@staticmethod
def _strip_media_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
"""Replace image_url and input_audio blocks with text placeholders.
"""Replace image_url, input_audio, and video_url blocks with text placeholders.
Returns None if no media blocks were found (no changes needed).
"""

View File

@ -93,6 +93,20 @@ def audio_format_for_api(mime: str) -> str:
return _AUDIO_FORMAT_MAP.get(mime, mime.split("/")[-1])
# Video formats commonly supported by LLM APIs (data URI inline)
_VIDEO_MIME_COMPAT = {
"video/mp4", "video/quicktime", "video/x-m4v",
"video/webm", "video/x-matroska",
}
def video_mime_compat(mime: str | None) -> bool:
"""Check if the video MIME is in the commonly-supported set."""
if not mime:
return False
return mime in _VIDEO_MIME_COMPAT
def build_image_content_blocks(raw: bytes, mime: str, path: str, label: str) -> list[dict[str, Any]]:
"""Build native image blocks plus a short text label."""
b64 = base64.b64encode(raw).decode()

View File

@ -27,10 +27,11 @@ def test_extract_post_content_supports_post_wrapper_shape() -> None:
}
}
text, image_keys = _extract_post_content(payload)
text, image_keys, media_items = _extract_post_content(payload)
assert text == "日报 完成"
assert image_keys == ["img_1"]
assert media_items == []
def test_extract_post_content_keeps_direct_shape_behavior() -> None:
@ -45,10 +46,38 @@ def test_extract_post_content_keeps_direct_shape_behavior() -> None:
],
}
text, image_keys = _extract_post_content(payload)
text, image_keys, media_items = _extract_post_content(payload)
assert text == "Daily report"
assert image_keys == ["img_a", "img_b"]
assert media_items == []
def test_extract_post_content_extracts_media_tags() -> None:
payload = {
"title": "",
"content": [
[{"tag": "img", "image_key": "img_1", "width": 345, "height": 34}],
[{"tag": "media", "file_key": "file_v3_0010j_abc", "image_key": "img_v3_0210j_xyz"}],
],
}
text, image_keys, media_items = _extract_post_content(payload)
assert image_keys == ["img_1"]
assert media_items == [{"tag": "media", "file_key": "file_v3_0010j_abc"}]
def test_extract_post_content_ignores_media_without_file_key() -> None:
payload = {
"content": [
[{"tag": "media"}],
],
}
text, image_keys, media_items = _extract_post_content(payload)
assert media_items == []
def test_register_optional_event_keeps_builder_when_method_missing() -> None:

View File

@ -338,11 +338,10 @@ class TestInputLimitsAudioVideo:
assert audio_blocks[0]["input_audio"]["format"] == "mp3"
def test_missing_file_gracefully_skipped(self, ctx, tmp_path):
"""Missing file should be gracefully skipped."""
"""Missing file should be skipped with a visible note."""
result = ctx._build_user_content("hello", [str(tmp_path / "ghost.wav")], supports_audio=True)
# Missing file is silently skipped (non-image path uses continue on OSError)
assert isinstance(result, str)
assert result == "hello"
assert "[Skipped file: unable to read" in result
# ── _strip_media_content ──────────────────────────────────────────────
@ -495,3 +494,175 @@ class TestCodexAudioConversion:
audio_items = [i for i in result["content"] if i.get("type") == "input_audio"]
assert len(audio_items) == 1
assert audio_items[0]["input_audio"]["data"] == "abc123"
def test_video_url_converted_to_text_placeholder(self):
from nanobot.providers.openai_codex_provider import _convert_user_message
content = [
{"type": "video_url", "video_url": {"url": "data:video/mp4;base64,abc"},
"_meta": {"path": "/video.mp4"}},
{"type": "text", "text": "watch"},
]
result = _convert_user_message(content)
text_items = [i for i in result["content"] if i.get("type") == "input_text"]
assert any("[video:" in i.get("text", "") for i in text_items)
# ── New tests for review fixes ──────────────────────────────────────────
class TestSupportsVisionFalse:
"""Tests for supports_vision=False (image downgrade to placeholder)."""
@pytest.fixture
def ctx(self, tmp_path):
return ContextBuilder(tmp_path, timezone="UTC")
def _make_png(self, size: int = 64) -> bytes:
import struct, zlib
header = b"\x89PNG\r\n\x1a\n"
ihdr_data = struct.pack(">IIBBBBB", 1, 1, 8, 2, 0, 0, 0)
ihdr_crc = zlib.crc32(b"IHDR" + ihdr_data) & 0xFFFFFFFF
ihdr = struct.pack(">I", 13) + b"IHDR" + ihdr_data + struct.pack(">I", ihdr_crc)
raw = b"\x00\x00\x00\x00"
idat_crc = zlib.crc32(b"IDAT" + raw) & 0xFFFFFFFF
idat = struct.pack(">I", len(raw)) + b"IDAT" + raw + struct.pack(">I", idat_crc)
iend_crc = zlib.crc32(b"IEND") & 0xFFFFFFFF
iend = struct.pack(">I", 0) + b"IEND" + struct.pack(">I", iend_crc)
return header + ihdr + idat + iend
def test_vision_false_downgrades_to_placeholder(self, ctx, tmp_path):
img_path = tmp_path / "test.png"
img_path.write_bytes(self._make_png())
result = ctx._build_user_content("look", [str(img_path)], supports_vision=False)
assert isinstance(result, list)
assert not any(b.get("type") == "image_url" for b in result)
assert any("[image:" in (b.get("text") or "") for b in result)
def test_vision_false_no_file_read(self, ctx, tmp_path):
"""With supports_vision=False, file should not be read (no crash on missing)."""
missing = tmp_path / "nonexistent.png"
result = ctx._build_user_content("look", [str(missing)], supports_vision=False)
assert isinstance(result, list)
assert any("[image:" in (b.get("text") or "") for b in result)
class TestAudioVideoCountLimits:
"""Tests for max_input_audios / max_input_videos count enforcement."""
@pytest.fixture
def ctx(self, tmp_path):
return ContextBuilder(tmp_path, timezone="UTC",
input_limits=InputLimitsConfig(
max_input_audios=1,
max_input_videos=1,
max_input_audio_bytes=10 * 1024 * 1024,
max_input_video_bytes=20 * 1024 * 1024,
))
def _make_wav(self) -> bytes:
data = b"\x00\x00"
fmt_chunk = (
b"\x01\x00" + (1).to_bytes(2, "little") + (44100).to_bytes(4, "little")
+ (88200).to_bytes(4, "little") + (2).to_bytes(2, "little")
+ (16).to_bytes(2, "little")
)
return (
b"RIFF" + (36 + len(data)).to_bytes(4, "little") + b"WAVE"
+ b"fmt " + (16).to_bytes(4, "little") + fmt_chunk
+ b"data" + len(data).to_bytes(4, "little") + data
)
def _make_mp4(self) -> bytes:
ftyp_data = b"isom" + b"\x00" * 12
return (8 + len(ftyp_data)).to_bytes(4, "big") + b"ftyp" + ftyp_data
def test_audio_count_limit_enforced(self, ctx, tmp_path):
"""Only first audio should be accepted; second should be skipped."""
wav1 = tmp_path / "a1.wav"
wav1.write_bytes(self._make_wav())
wav2 = tmp_path / "a2.wav"
wav2.write_bytes(self._make_wav())
result = ctx._build_user_content("listen", [str(wav1), str(wav2)], supports_audio=True)
# Should have note about skip + one audio block
if isinstance(result, list):
audio_blocks = [b for b in result if b.get("type") == "input_audio"]
assert len(audio_blocks) == 1
text_blocks = [b for b in result if b.get("type") == "text"]
notes_text = " ".join(b.get("text", "") for b in text_blocks)
assert "Skipped audio" in notes_text
else:
# All skipped, result is string
assert "Skipped audio" in result
def test_video_count_limit_enforced(self, ctx, tmp_path):
"""Only first video should be accepted; second should be skipped."""
mp4_1 = tmp_path / "v1.mp4"
mp4_1.write_bytes(self._make_mp4())
mp4_2 = tmp_path / "v2.mp4"
mp4_2.write_bytes(self._make_mp4())
result = ctx._build_user_content("watch", [str(mp4_1), str(mp4_2)], supports_video=True)
if isinstance(result, list):
video_blocks = [b for b in result if b.get("type") == "video_url"]
assert len(video_blocks) == 1
text_blocks = [b for b in result if b.get("type") == "text"]
notes_text = " ".join(b.get("text", "") for b in text_blocks)
assert "Skipped video" in notes_text
else:
assert "Skipped video" in result
class TestVideoMimeCompat:
"""Tests for video_mime_compat function."""
def test_compatible_mp4(self):
from nanobot.utils.helpers import video_mime_compat
assert video_mime_compat("video/mp4") is True
def test_compatible_webm(self):
from nanobot.utils.helpers import video_mime_compat
assert video_mime_compat("video/webm") is True
def test_compatible_quicktime(self):
from nanobot.utils.helpers import video_mime_compat
assert video_mime_compat("video/quicktime") is True
def test_incompatible_avi(self):
from nanobot.utils.helpers import video_mime_compat
assert video_mime_compat("video/x-msvideo") is False
def test_none(self):
from nanobot.utils.helpers import video_mime_compat
assert video_mime_compat(None) is False
class TestSupportsAudioCaseInsensitive:
"""Case insensitivity for supports_audio / supports_video."""
def test_audio_case_insensitive(self):
d = AgentDefaults(audio_models=["GPT-4o"])
assert d.supports_audio("openai/gpt-4o-audio") is True
def test_video_case_insensitive(self):
d = AgentDefaults(video_models=["GLM-5V"])
assert d.supports_video("zhipu/glm-5v-turbo") is True
class TestNonImageOSErrorNote:
"""Non-image media OSError should produce a visible note."""
@pytest.fixture
def ctx(self, tmp_path):
return ContextBuilder(tmp_path, timezone="UTC")
def test_missing_audio_produces_note(self, ctx, tmp_path):
result = ctx._build_user_content(
"hello", [str(tmp_path / "missing.wav")], supports_audio=True
)
assert isinstance(result, str)
assert "[Skipped file: unable to read" in result
def test_missing_video_produces_note(self, ctx, tmp_path):
result = ctx._build_user_content(
"hello", [str(tmp_path / "missing.mp4")], supports_video=True
)
assert isinstance(result, str)
assert "[Skipped file: unable to read" in result