fix(multimodal): image OOM guard, Feishu post media extraction, vision fallback

- Add file size pre-check via stat() before read_bytes() to prevent OOM
  on oversized images/audio/video
- Fix _extract_post_content to extract media tags (file_key) from Feishu
  post messages so videos are no longer silently dropped
- Add supports_vision=False guard to downgrade images to text placeholders
- Add video_mime_compat() for video format validation
- Use full file path in content_text so model read_file works if needed
- Pass input_limits to AgentLoop in nanobot.py facade
- Deduplicate _MEDIA_PLACEHOLDER_TYPES by reusing the LLMProvider._STRIP_MEDIA_TYPES constant
- Remove unused _extract_post_text legacy wrapper
- Add 14 new tests covering vision fallback, count limits, video compat
This commit is contained in:
chengyongru 2026-04-09 01:13:40 +08:00
parent b9346b0d59
commit c053f9eba8
8 changed files with 298 additions and 44 deletions

View File

@ -2,6 +2,7 @@
import base64 import base64
import mimetypes import mimetypes
import os
import platform import platform
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -16,6 +17,7 @@ from nanobot.utils.helpers import (
current_time_str, current_time_str,
detect_audio_mime, detect_audio_mime,
detect_image_mime, detect_image_mime,
video_mime_compat,
) )
from nanobot.utils.prompt_templates import render_template from nanobot.utils.prompt_templates import render_template
@ -33,6 +35,18 @@ class ContextBuilder:
self.skills = SkillsLoader(workspace) self.skills = SkillsLoader(workspace)
self.input_limits = input_limits or InputLimitsConfig() self.input_limits = input_limits or InputLimitsConfig()
@staticmethod
def _file_size_ok(p: Path, max_bytes: int) -> bool | None:
"""Check file size via stat without reading into memory.
Returns True if size is within limit, False if oversized,
None if file cannot be stat'd (caller should try read_bytes instead).
"""
try:
return os.stat(p).st_size <= max_bytes
except OSError:
return None
def build_system_prompt(self, skill_names: list[str] | None = None) -> str: def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
"""Build the system prompt from identity, bootstrap files, memory, and skills.""" """Build the system prompt from identity, bootstrap files, memory, and skills."""
parts = [self._get_identity()] parts = [self._get_identity()]
@ -171,7 +185,8 @@ class ContextBuilder:
text: The user text message. text: The user text message.
media: List of file paths to media files. media: List of file paths to media files.
supports_vision: True=model supports images, False=use placeholder, supports_vision: True=model supports images, False=use placeholder,
None=unconfigured (send images as before). None=unconfigured (send images as before, let
provider/retry handle degradation).
supports_audio: True=model supports native audio, False/None=skip supports_audio: True=model supports native audio, False/None=skip
(channel layer already transcribed). (channel layer already transcribed).
supports_video: True=model supports native video, False/None=use supports_video: True=model supports native video, False/None=use
@ -210,15 +225,22 @@ class ContextBuilder:
# Process images # Process images
for path in image_media: for path in image_media:
p = Path(path) p = Path(path)
# When explicitly marked as non-vision, downgrade to text placeholder
if supports_vision is False:
blocks.append({"type": "text", "text": f"[image: {p}]"})
continue
size_ok = self._file_size_ok(p, limits.max_input_image_bytes)
if size_ok is False:
size_mb = limits.max_input_image_bytes // (1024 * 1024)
notes.append(f"[Skipped image: file too large ({p.name}, limit {size_mb} MB)]")
continue
try: try:
raw = p.read_bytes() raw = p.read_bytes()
except OSError: except OSError:
notes.append(f"[Skipped image: unable to read ({p.name or path})]") notes.append(f"[Skipped image: unable to read ({p.name or path})]")
continue continue
if len(raw) > limits.max_input_image_bytes:
size_mb = limits.max_input_image_bytes // (1024 * 1024)
notes.append(f"[Skipped image: file too large ({p.name}, limit {size_mb} MB)]")
continue
img_mime = detect_image_mime(raw[:32]) or mimetypes.guess_type(path)[0] img_mime = detect_image_mime(raw[:32]) or mimetypes.guess_type(path)[0]
if not img_mime or not img_mime.startswith("image/"): if not img_mime or not img_mime.startswith("image/"):
notes.append(f"[Skipped image: unsupported or invalid image format ({p.name})]") notes.append(f"[Skipped image: unsupported or invalid image format ({p.name})]")
@ -232,10 +254,24 @@ class ContextBuilder:
p = Path(path) p = Path(path)
guessed_mime = mimetypes.guess_type(path)[0] or "" guessed_mime = mimetypes.guess_type(path)[0] or ""
is_audio = guessed_mime.startswith("audio/") is_audio = guessed_mime.startswith("audio/")
is_video = guessed_mime.startswith("video/")
# Pre-check file size via stat to avoid reading oversized files into memory.
# Determine the relevant byte limit based on detected media type.
_size_limit = 0
if is_audio or is_video:
_size_limit = limits.max_input_audio_bytes if is_audio else limits.max_input_video_bytes
_stat_size_ok = self._file_size_ok(p, _size_limit) if _size_limit else None
if _stat_size_ok is False:
size_mb = _size_limit // (1024 * 1024)
label = "audio" if is_audio else "video"
notes.append(f"[Skipped {label}: file too large ({p.name}, limit {size_mb} MB)]")
continue
try: try:
raw = p.read_bytes() raw = p.read_bytes()
except OSError: except OSError:
notes.append(f"[Skipped file: unable to read ({p.name or path})]")
continue continue
# Audio detection: by magic bytes or by filename # Audio detection: by magic bytes or by filename
@ -264,10 +300,9 @@ class ContextBuilder:
blocks.append({"type": "text", "text": f"[audio: {p}]"}) blocks.append({"type": "text", "text": f"[audio: {p}]"})
continue continue
# Video detection: by filename extension # Video detection (already classified above)
is_video = guessed_mime.startswith("video/")
if is_video: if is_video:
if supports_video is True: if supports_video is True and video_mime_compat(guessed_mime):
video_count += 1 video_count += 1
if video_count > limits.max_input_videos: if video_count > limits.max_input_videos:
if video_count == limits.max_input_videos + 1: if video_count == limits.max_input_videos + 1:

View File

@ -632,7 +632,7 @@ class AgentLoop:
metadata=meta, metadata=meta,
) )
_MEDIA_PLACEHOLDER_TYPES = {"image_url", "input_audio", "video_url"} _MEDIA_PLACEHOLDER_TYPES = LLMProvider._STRIP_MEDIA_TYPES
def _sanitize_persisted_blocks( def _sanitize_persisted_blocks(
self, self,

View File

@ -169,19 +169,22 @@ def _extract_element_content(element: dict) -> list[str]:
return parts return parts
def _extract_post_content(content_json: dict) -> tuple[str, list[str]]: def _extract_post_content(content_json: dict) -> tuple[str, list[str], list[dict]]:
"""Extract text and image keys from Feishu post (rich text) message. """Extract text and media info from Feishu post (rich text) message.
Handles three payload shapes: Handles three payload shapes:
- Direct: {"title": "...", "content": [[...]]} - Direct: {"title": "...", "content": [[...]]}
- Localized: {"zh_cn": {"title": "...", "content": [...]}} - Localized: {"zh_cn": {"title": "...", "content": [...]}}
- Wrapped: {"post": {"zh_cn": {"title": "...", "content": [...]}}} - Wrapped: {"post": {"zh_cn": {"title": "...", "content": [...]}}}
Returns (text, image_keys, media_items) where media_items is a list of
{"tag": "media", "file_key": "..."} dicts for video/file attachments.
""" """
def _parse_block(block: dict) -> tuple[str | None, list[str]]: def _parse_block(block: dict) -> tuple[str | None, list[str], list[dict]]:
if not isinstance(block, dict) or not isinstance(block.get("content"), list): if not isinstance(block, dict) or not isinstance(block.get("content"), list):
return None, [] return None, [], []
texts, images = [], [] texts, images, medias = [], [], []
if title := block.get("title"): if title := block.get("title"):
texts.append(title) texts.append(title)
for row in block["content"]: for row in block["content"]:
@ -201,43 +204,36 @@ def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
texts.append(f"\n```{lang}\n{code_text}\n```\n") texts.append(f"\n```{lang}\n{code_text}\n```\n")
elif tag == "img" and (key := el.get("image_key")): elif tag == "img" and (key := el.get("image_key")):
images.append(key) images.append(key)
return (" ".join(texts).strip() or None), images elif tag == "media" and el.get("file_key"):
medias.append({"tag": "media", "file_key": el["file_key"]})
return (" ".join(texts).strip() or None), images, medias
# Unwrap optional {"post": ...} envelope # Unwrap optional {"post": ...} envelope
root = content_json root = content_json
if isinstance(root, dict) and isinstance(root.get("post"), dict): if isinstance(root, dict) and isinstance(root.get("post"), dict):
root = root["post"] root = root["post"]
if not isinstance(root, dict): if not isinstance(root, dict):
return "", [] return "", [], []
# Direct format # Direct format
if "content" in root: if "content" in root:
text, imgs = _parse_block(root) text, imgs, medias = _parse_block(root)
if text or imgs: if text or imgs or medias:
return text or "", imgs return text or "", imgs, medias
# Localized: prefer known locales, then fall back to any dict child # Localized: prefer known locales, then fall back to any dict child
for key in ("zh_cn", "en_us", "ja_jp"): for key in ("zh_cn", "en_us", "ja_jp"):
if key in root: if key in root:
text, imgs = _parse_block(root[key]) text, imgs, medias = _parse_block(root[key])
if text or imgs: if text or imgs or medias:
return text or "", imgs return text or "", imgs, medias
for val in root.values(): for val in root.values():
if isinstance(val, dict): if isinstance(val, dict):
text, imgs = _parse_block(val) text, imgs, medias = _parse_block(val)
if text or imgs: if text or imgs or medias:
return text or "", imgs return text or "", imgs, medias
return "", [] return "", [], []
def _extract_post_text(content_json: dict) -> str:
"""Extract plain text from Feishu post (rich text) message content.
Legacy wrapper for _extract_post_content, returns only text.
"""
text, _ = _extract_post_content(content_json)
return text
class FeishuConfig(Base): class FeishuConfig(Base):
@ -1027,7 +1023,7 @@ class FeishuChannel(BaseChannel):
file_path = media_dir / filename file_path = media_dir / filename
file_path.write_bytes(data) file_path.write_bytes(data)
logger.debug("Downloaded {} to {}", msg_type, file_path) logger.debug("Downloaded {} to {}", msg_type, file_path)
return str(file_path), f"[{msg_type}: {filename}]" return str(file_path), f"[{msg_type}: {file_path}]"
return None, f"[{msg_type}: download failed]" return None, f"[{msg_type}: download failed]"
@ -1067,7 +1063,7 @@ class FeishuChannel(BaseChannel):
if msg_type == "text": if msg_type == "text":
text = content_json.get("text", "").strip() text = content_json.get("text", "").strip()
elif msg_type == "post": elif msg_type == "post":
text, _ = _extract_post_content(content_json) text, _, _ = _extract_post_content(content_json)
text = text.strip() text = text.strip()
else: else:
text = "" text = ""
@ -1542,7 +1538,7 @@ class FeishuChannel(BaseChannel):
content_parts.append(text) content_parts.append(text)
elif msg_type == "post": elif msg_type == "post":
text, image_keys = _extract_post_content(content_json) text, image_keys, media_items = _extract_post_content(content_json)
if text: if text:
content_parts.append(text) content_parts.append(text)
# Download images embedded in post # Download images embedded in post
@ -1553,6 +1549,14 @@ class FeishuChannel(BaseChannel):
if file_path: if file_path:
media_paths.append(file_path) media_paths.append(file_path)
content_parts.append(content_text) content_parts.append(content_text)
# Download media (video/file) embedded in post
for media_item in media_items:
file_path, content_text = await self._download_and_save_media(
"media", media_item, message_id
)
if file_path:
media_paths.append(file_path)
content_parts.append(content_text)
elif msg_type in ("image", "audio", "file", "media"): elif msg_type in ("image", "audio", "file", "media"):
file_path, content_text = await self._download_and_save_media( file_path, content_text = await self._download_and_save_media(

View File

@ -78,6 +78,7 @@ class Nanobot:
provider_retry_mode=defaults.provider_retry_mode, provider_retry_mode=defaults.provider_retry_mode,
web_config=config.tools.web, web_config=config.tools.web,
exec_config=config.tools.exec, exec_config=config.tools.exec,
input_limits=config.tools.input_limits,
restrict_to_workspace=config.tools.restrict_to_workspace, restrict_to_workspace=config.tools.restrict_to_workspace,
mcp_servers=config.tools.mcp_servers, mcp_servers=config.tools.mcp_servers,
timezone=defaults.timezone, timezone=defaults.timezone,

View File

@ -369,7 +369,7 @@ class LLMProvider(ABC):
@staticmethod @staticmethod
def _strip_media_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None: def _strip_media_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
"""Replace image_url and input_audio blocks with text placeholders. """Replace image_url, input_audio, and video_url blocks with text placeholders.
Returns None if no media blocks were found (no changes needed). Returns None if no media blocks were found (no changes needed).
""" """

View File

@ -93,6 +93,20 @@ def audio_format_for_api(mime: str) -> str:
return _AUDIO_FORMAT_MAP.get(mime, mime.split("/")[-1]) return _AUDIO_FORMAT_MAP.get(mime, mime.split("/")[-1])
# Video formats commonly supported by LLM APIs (data URI inline)
_VIDEO_MIME_COMPAT = {
"video/mp4", "video/quicktime", "video/x-m4v",
"video/webm", "video/x-matroska",
}
def video_mime_compat(mime: str | None) -> bool:
"""Check if the video MIME is in the commonly-supported set."""
if not mime:
return False
return mime in _VIDEO_MIME_COMPAT
def build_image_content_blocks(raw: bytes, mime: str, path: str, label: str) -> list[dict[str, Any]]: def build_image_content_blocks(raw: bytes, mime: str, path: str, label: str) -> list[dict[str, Any]]:
"""Build native image blocks plus a short text label.""" """Build native image blocks plus a short text label."""
b64 = base64.b64encode(raw).decode() b64 = base64.b64encode(raw).decode()

View File

@ -27,10 +27,11 @@ def test_extract_post_content_supports_post_wrapper_shape() -> None:
} }
} }
text, image_keys = _extract_post_content(payload) text, image_keys, media_items = _extract_post_content(payload)
assert text == "日报 完成" assert text == "日报 完成"
assert image_keys == ["img_1"] assert image_keys == ["img_1"]
assert media_items == []
def test_extract_post_content_keeps_direct_shape_behavior() -> None: def test_extract_post_content_keeps_direct_shape_behavior() -> None:
@ -45,10 +46,38 @@ def test_extract_post_content_keeps_direct_shape_behavior() -> None:
], ],
} }
text, image_keys = _extract_post_content(payload) text, image_keys, media_items = _extract_post_content(payload)
assert text == "Daily report" assert text == "Daily report"
assert image_keys == ["img_a", "img_b"] assert image_keys == ["img_a", "img_b"]
assert media_items == []
def test_extract_post_content_extracts_media_tags() -> None:
    """A post with one img row and one media row yields one image key and one media item."""
    image_row = [{"tag": "img", "image_key": "img_1", "width": 345, "height": 34}]
    media_row = [
        {"tag": "media", "file_key": "file_v3_0010j_abc", "image_key": "img_v3_0210j_xyz"}
    ]
    payload = {"title": "", "content": [image_row, media_row]}

    _text, image_keys, media_items = _extract_post_content(payload)

    # The media element's own image_key must not leak into image_keys.
    assert image_keys == ["img_1"]
    assert media_items == [{"tag": "media", "file_key": "file_v3_0010j_abc"}]
def test_extract_post_content_ignores_media_without_file_key() -> None:
    """A media element that has no file_key contributes no media item."""
    bare_media_row = [{"tag": "media"}]
    payload = {"content": [bare_media_row]}

    _text, _image_keys, media_items = _extract_post_content(payload)

    assert media_items == []
def test_register_optional_event_keeps_builder_when_method_missing() -> None: def test_register_optional_event_keeps_builder_when_method_missing() -> None:

View File

@ -338,11 +338,10 @@ class TestInputLimitsAudioVideo:
assert audio_blocks[0]["input_audio"]["format"] == "mp3" assert audio_blocks[0]["input_audio"]["format"] == "mp3"
def test_missing_file_gracefully_skipped(self, ctx, tmp_path): def test_missing_file_gracefully_skipped(self, ctx, tmp_path):
"""Missing file should be gracefully skipped.""" """Missing file should be skipped with a visible note."""
result = ctx._build_user_content("hello", [str(tmp_path / "ghost.wav")], supports_audio=True) result = ctx._build_user_content("hello", [str(tmp_path / "ghost.wav")], supports_audio=True)
# Missing file is silently skipped (non-image path uses continue on OSError)
assert isinstance(result, str) assert isinstance(result, str)
assert result == "hello" assert "[Skipped file: unable to read" in result
# ── _strip_media_content ────────────────────────────────────────────── # ── _strip_media_content ──────────────────────────────────────────────
@ -495,3 +494,175 @@ class TestCodexAudioConversion:
audio_items = [i for i in result["content"] if i.get("type") == "input_audio"] audio_items = [i for i in result["content"] if i.get("type") == "input_audio"]
assert len(audio_items) == 1 assert len(audio_items) == 1
assert audio_items[0]["input_audio"]["data"] == "abc123" assert audio_items[0]["input_audio"]["data"] == "abc123"
def test_video_url_converted_to_text_placeholder(self):
    """A video_url block is expected to appear as an input_text item containing "[video:"."""
    from nanobot.providers.openai_codex_provider import _convert_user_message

    video_block = {
        "type": "video_url",
        "video_url": {"url": "data:video/mp4;base64,abc"},
        "_meta": {"path": "/video.mp4"},
    }
    prompt_block = {"type": "text", "text": "watch"}

    converted = _convert_user_message([video_block, prompt_block])

    input_texts = [
        item.get("text", "")
        for item in converted["content"]
        if item.get("type") == "input_text"
    ]
    assert any("[video:" in text for text in input_texts)
# ── New tests for review fixes ──────────────────────────────────────────
class TestSupportsVisionFalse:
    """Checks that supports_vision=False turns image inputs into text placeholders."""

    @pytest.fixture
    def ctx(self, tmp_path):
        """Provide a ContextBuilder rooted at the per-test temporary directory."""
        return ContextBuilder(tmp_path, timezone="UTC")

    def _make_png(self, size: int = 64) -> bytes:
        """Build PNG bytes (signature, IHDR, IDAT, IEND); ``size`` is accepted but unused."""
        import struct
        import zlib

        def chunk(tag: bytes, payload: bytes) -> bytes:
            # PNG chunk layout: big-endian length, tag, payload, CRC over tag+payload.
            crc = zlib.crc32(tag + payload) & 0xFFFFFFFF
            return struct.pack(">I", len(payload)) + tag + payload + struct.pack(">I", crc)

        signature = b"\x89PNG\r\n\x1a\n"
        header_fields = struct.pack(">IIBBBBB", 1, 1, 8, 2, 0, 0, 0)
        return (
            signature
            + chunk(b"IHDR", header_fields)
            + chunk(b"IDAT", b"\x00\x00\x00\x00")
            + chunk(b"IEND", b"")
        )

    def test_vision_false_downgrades_to_placeholder(self, ctx, tmp_path):
        """An existing image yields an "[image:" text block and no image_url block."""
        image_file = tmp_path / "test.png"
        image_file.write_bytes(self._make_png())

        result = ctx._build_user_content("look", [str(image_file)], supports_vision=False)

        assert isinstance(result, list)
        block_types = [block.get("type") for block in result]
        assert "image_url" not in block_types
        assert any("[image:" in (block.get("text") or "") for block in result)

    def test_vision_false_no_file_read(self, ctx, tmp_path):
        """With supports_vision=False, file should not be read (no crash on missing)."""
        absent = tmp_path / "nonexistent.png"

        result = ctx._build_user_content("look", [str(absent)], supports_vision=False)

        assert isinstance(result, list)
        assert any("[image:" in (block.get("text") or "") for block in result)
class TestAudioVideoCountLimits:
    """Checks that max_input_audios / max_input_videos cap how many files are accepted."""

    @pytest.fixture
    def ctx(self, tmp_path):
        """Provide a ContextBuilder limited to one audio and one video input."""
        limits = InputLimitsConfig(
            max_input_audios=1,
            max_input_videos=1,
            max_input_audio_bytes=10 * 1024 * 1024,
            max_input_video_bytes=20 * 1024 * 1024,
        )
        return ContextBuilder(tmp_path, timezone="UTC", input_limits=limits)

    def _make_wav(self) -> bytes:
        """Build RIFF/WAVE bytes: a 16-byte fmt chunk and a 2-byte data chunk."""

        def le(value: int, width: int) -> bytes:
            return value.to_bytes(width, "little")

        samples = b"\x00\x00"
        fmt_body = (
            b"\x01\x00" + le(1, 2) + le(44100, 4) + le(88200, 4) + le(2, 2) + le(16, 2)
        )
        pieces = [
            b"RIFF", le(36 + len(samples), 4), b"WAVE",
            b"fmt ", le(16, 4), fmt_body,
            b"data", le(len(samples), 4), samples,
        ]
        return b"".join(pieces)

    def _make_mp4(self) -> bytes:
        """Build bytes consisting of a single "ftyp" box with brand "isom"."""
        box_payload = b"isom" + b"\x00" * 12
        box_size = 8 + len(box_payload)
        return box_size.to_bytes(4, "big") + b"ftyp" + box_payload

    @staticmethod
    def _assert_single_accepted(result, block_type: str, skip_note: str) -> None:
        """Assert one block of ``block_type`` was kept and a skip note is present."""
        if not isinstance(result, list):
            # All skipped, result is string
            assert skip_note in result
            return
        accepted = [b for b in result if b.get("type") == block_type]
        assert len(accepted) == 1
        notes_text = " ".join(b.get("text", "") for b in result if b.get("type") == "text")
        assert skip_note in notes_text

    def test_audio_count_limit_enforced(self, ctx, tmp_path):
        """Only first audio should be accepted; second should be skipped."""
        paths = []
        for name in ("a1.wav", "a2.wav"):
            wav_file = tmp_path / name
            wav_file.write_bytes(self._make_wav())
            paths.append(str(wav_file))

        result = ctx._build_user_content("listen", paths, supports_audio=True)

        # Should have note about skip + one audio block
        self._assert_single_accepted(result, "input_audio", "Skipped audio")

    def test_video_count_limit_enforced(self, ctx, tmp_path):
        """Only first video should be accepted; second should be skipped."""
        paths = []
        for name in ("v1.mp4", "v2.mp4"):
            mp4_file = tmp_path / name
            mp4_file.write_bytes(self._make_mp4())
            paths.append(str(mp4_file))

        result = ctx._build_user_content("watch", paths, supports_video=True)

        self._assert_single_accepted(result, "video_url", "Skipped video")
class TestVideoMimeCompat:
    """Checks which MIME values video_mime_compat accepts and rejects."""

    @staticmethod
    def _compat(mime):
        """Import video_mime_compat at call time and apply it to ``mime``."""
        from nanobot.utils.helpers import video_mime_compat

        return video_mime_compat(mime)

    def test_compatible_mp4(self):
        assert self._compat("video/mp4") is True

    def test_compatible_webm(self):
        assert self._compat("video/webm") is True

    def test_compatible_quicktime(self):
        assert self._compat("video/quicktime") is True

    def test_incompatible_avi(self):
        assert self._compat("video/x-msvideo") is False

    def test_none(self):
        assert self._compat(None) is False
class TestSupportsAudioCaseInsensitive:
    """Mixed-case configured model names should still match lower-case model ids."""

    def test_audio_case_insensitive(self):
        """audio_models=["GPT-4o"] is expected to cover "openai/gpt-4o-audio"."""
        defaults = AgentDefaults(audio_models=["GPT-4o"])
        supported = defaults.supports_audio("openai/gpt-4o-audio")
        assert supported is True

    def test_video_case_insensitive(self):
        """video_models=["GLM-5V"] is expected to cover "zhipu/glm-5v-turbo"."""
        defaults = AgentDefaults(video_models=["GLM-5V"])
        supported = defaults.supports_video("zhipu/glm-5v-turbo")
        assert supported is True
class TestNonImageOSErrorNote:
    """Unreadable non-image media should leave a visible note in the result."""

    _EXPECTED_NOTE = "[Skipped file: unable to read"

    @pytest.fixture
    def ctx(self, tmp_path):
        """Provide a ContextBuilder rooted at the per-test temporary directory."""
        return ContextBuilder(tmp_path, timezone="UTC")

    def test_missing_audio_produces_note(self, ctx, tmp_path):
        """A missing .wav path gives a string result carrying the skip note."""
        absent_audio = str(tmp_path / "missing.wav")
        result = ctx._build_user_content("hello", [absent_audio], supports_audio=True)
        assert isinstance(result, str)
        assert self._EXPECTED_NOTE in result

    def test_missing_video_produces_note(self, ctx, tmp_path):
        """A missing .mp4 path gives a string result carrying the skip note."""
        absent_video = str(tmp_path / "missing.mp4")
        result = ctx._build_user_content("hello", [absent_video], supports_video=True)
        assert isinstance(result, str)
        assert self._EXPECTED_NOTE in result