fix(email): bound outbound attachment handling

maintainer edit: apply the existing email attachment count and size limits to outbound media, and include visible fallback notes when an attachment cannot be sent.
This commit is contained in:
chengyongru 2026-06-02 18:18:52 +08:00 committed by Xubin Ren
parent 82a3fd03b1
commit b2ae5d936f
2 changed files with 152 additions and 22 deletions

View File

@ -208,33 +208,61 @@ class EmailChannel(BaseChannel):
if override:
subject = override
attachments: list[tuple[bytes, str, str, str]] = []
failed_attachments: list[str] = []
max_attachment_size = max(0, int(self.config.max_attachment_size))
max_attachment_count = max(0, int(self.config.max_attachments_per_email))
for media_path in msg.media or []:
path = Path(media_path)
filename = path.name or "attachment"
if len(attachments) >= max_attachment_count:
failed_attachments.append(f"[attachment: {filename} - too many attachments]")
self.logger.warning("Attachment count limit reached, skipping: {}", media_path)
continue
if not path.is_file():
failed_attachments.append(f"[attachment: {filename} - send failed]")
self.logger.warning("Attachment not found, skipping: {}", media_path)
continue
try:
size = path.stat().st_size
if max_attachment_size <= 0 or size > max_attachment_size:
failed_attachments.append(f"[attachment: {filename} - too large]")
self.logger.warning(
"Attachment too large, skipping: {} ({} > {} bytes)",
media_path,
size,
max_attachment_size,
)
continue
data = path.read_bytes()
ctype, _ = mimetypes.guess_type(str(path))
if ctype is None:
ctype = "application/octet-stream"
maintype, subtype = ctype.split("/", 1)
attachments.append((data, maintype, subtype, filename))
self.logger.info("Attached file: {}", filename)
except Exception:
failed_attachments.append(f"[attachment: {filename} - send failed]")
self.logger.exception("Failed to attach file {}", media_path)
content = msg.content or ""
if failed_attachments:
fallback = "\n".join(failed_attachments)
content = f"{content.rstrip()}\n\n{fallback}" if content.strip() else fallback
email_msg = EmailMessage()
email_msg["From"] = self.config.from_address or self.config.smtp_username or self.config.imap_username
email_msg["To"] = to_addr
email_msg["Subject"] = subject
email_msg.set_content(msg.content or "")
email_msg.set_content(content)
# Attach media files
for media_path in msg.media or []:
path = Path(media_path)
if not path.is_file():
self.logger.warning("Attachment not found, skipping: {}", media_path)
continue
try:
data = path.read_bytes()
ctype, encoding = mimetypes.guess_type(str(path))
if ctype is None:
ctype = "application/octet-stream"
maintype, subtype = ctype.split("/", 1)
email_msg.add_attachment(
data,
maintype=maintype,
subtype=subtype,
filename=path.name,
)
self.logger.info("Attached file: {}", path.name)
except Exception:
self.logger.exception("Failed to attach file {}", media_path)
for data, maintype, subtype, filename in attachments:
email_msg.add_attachment(
data,
maintype=maintype,
subtype=subtype,
filename=filename,
)
in_reply_to = self._last_message_id_by_chat.get(to_addr)
if in_reply_to:

View File

@ -1175,6 +1175,108 @@ async def test_send_skips_missing_attachment_file(tmp_path, monkeypatch) -> None
# Only the existing file should be attached
assert len(attachment_parts) == 1
assert attachment_parts[0].get_filename() == "real.txt"
body = sent.get_body(preferencelist=("plain",))
assert body is not None
assert "[attachment: nonexistent.pdf - send failed]" in body.get_content()
@pytest.mark.asyncio
async def test_send_skips_oversized_attachment_file(tmp_path, monkeypatch) -> None:
"""Attachment exceeding max_attachment_size is skipped with a visible note."""
sent_messages: list[EmailMessage] = []
class FakeSMTP:
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
self.timeout = timeout
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def starttls(self, context=None):
return None
def login(self, _user: str, _pw: str):
return None
def send_message(self, msg: EmailMessage):
sent_messages.append(msg)
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", lambda h, p, timeout=30: FakeSMTP(h, p, timeout=timeout))
attachment = tmp_path / "too-large.bin"
attachment.write_bytes(b"1234")
channel = EmailChannel(_make_config(max_attachment_size=3), MessageBus())
await channel.send(
OutboundMessage(
channel="email",
chat_id="alice@example.com",
content="Attachment should be skipped.",
media=[str(attachment)],
)
)
assert len(sent_messages) == 1
sent = sent_messages[0]
assert not sent.is_multipart()
assert "[attachment: too-large.bin - too large]" in sent.get_content()
@pytest.mark.asyncio
async def test_send_limits_outbound_attachment_count(tmp_path, monkeypatch) -> None:
"""Only max_attachments_per_email outbound attachments are included."""
sent_messages: list[EmailMessage] = []
class FakeSMTP:
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
self.timeout = timeout
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def starttls(self, context=None):
return None
def login(self, _user: str, _pw: str):
return None
def send_message(self, msg: EmailMessage):
sent_messages.append(msg)
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", lambda h, p, timeout=30: FakeSMTP(h, p, timeout=timeout))
file1 = tmp_path / "first.txt"
file1.write_text("first")
file2 = tmp_path / "second.txt"
file2.write_text("second")
channel = EmailChannel(_make_config(max_attachments_per_email=1), MessageBus())
await channel.send(
OutboundMessage(
channel="email",
chat_id="alice@example.com",
content="Only one attachment should be sent.",
media=[str(file1), str(file2)],
)
)
assert len(sent_messages) == 1
sent = sent_messages[0]
attachment_parts = []
for part in sent.walk():
if part.get_content_disposition() == "attachment":
attachment_parts.append(part)
assert len(attachment_parts) == 1
assert attachment_parts[0].get_filename() == "first.txt"
body = sent.get_body(preferencelist=("plain",))
assert body is not None
assert "[attachment: second.txt - too many attachments]" in body.get_content()
@pytest.mark.asyncio