mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-30 06:45:55 +00:00
317 lines
11 KiB
Python
317 lines
11 KiB
Python
"""Tests for document text extraction utilities."""
|
|
|
|
from pathlib import Path
|
|
|
|
from nanobot.utils.document import (
|
|
SUPPORTED_EXTENSIONS,
|
|
_is_text_extension,
|
|
extract_text,
|
|
)
|
|
|
|
|
|
class TestSupportedExtensions:
    """Checks on the module-level ``SUPPORTED_EXTENSIONS`` collection."""

    def test_supported_extensions_include_common_formats(self):
        """Common office, text, and image extensions must all be supported."""
        office_formats = (".pdf", ".docx", ".xlsx", ".pptx")
        text_formats = (".txt", ".md", ".csv", ".json", ".yaml", ".yml")
        image_formats = (".png", ".jpg", ".jpeg")

        # Checked in the same order as the groups above; the extension is used
        # as the assertion message so a failure names the missing format.
        for ext in (*office_formats, *text_formats, *image_formats):
            assert ext in SUPPORTED_EXTENSIONS, ext
|
|
|
|
|
class TestExtractText:
    """Tests for ``extract_text`` across plain-text, Office, PDF and image inputs."""

    def test_extract_text_unsupported_returns_none(self, tmp_path: Path):
        """Test that unsupported file types return None."""
        unsupported_file = tmp_path / "file.xyz"
        unsupported_file.write_text("content", encoding="utf-8")

        result = extract_text(unsupported_file)
        assert result is None

    def test_extract_text_file_not_found(self, tmp_path: Path):
        """Test that non-existent files return an error string."""
        missing_file = tmp_path / "nonexistent.txt"

        result = extract_text(missing_file)
        assert result is not None
        assert "[error: file not found:" in result

    def test_extract_text_txt_file(self, tmp_path: Path):
        """Test extracting text from a .txt file."""
        txt_file = tmp_path / "test.txt"
        content = "Hello, world!\nThis is a test."
        txt_file.write_text(content, encoding="utf-8")

        result = extract_text(txt_file)
        assert result == content

    def test_extract_text_txt_file_with_truncation(self, tmp_path: Path):
        """Test that large text files are truncated and carry a truncation marker."""
        txt_file = tmp_path / "large.txt"
        # Assumed to exceed the module's text-length cap (_MAX_TEXT_LENGTH in
        # nanobot.utils.document) — TODO confirm if that limit changes.
        original_length = 300_000
        content = "x" * original_length
        txt_file.write_text(content, encoding="utf-8")

        result = extract_text(txt_file)
        # Guard first so a None result fails as an assertion, not a TypeError from len().
        assert result is not None
        assert len(result) < original_length
        assert "(truncated," in result
        assert "chars total)" in result

    def test_extract_text_md_file(self, tmp_path: Path):
        """Test extracting text from a .md file."""
        md_file = tmp_path / "test.md"
        content = "# Header\n\nSome markdown content."
        md_file.write_text(content, encoding="utf-8")

        result = extract_text(md_file)
        assert result == content

    def test_extract_text_csv_file(self, tmp_path: Path):
        """Test extracting text from a .csv file."""
        csv_file = tmp_path / "test.csv"
        content = "name,age\nAlice,30\nBob,25"
        csv_file.write_text(content, encoding="utf-8")

        result = extract_text(csv_file)
        assert result == content

    def test_extract_text_json_file(self, tmp_path: Path):
        """Test extracting text from a .json file."""
        json_file = tmp_path / "test.json"
        content = '{"key": "value", "number": 42}'
        json_file.write_text(content, encoding="utf-8")

        result = extract_text(json_file)
        assert result == content

    def test_extract_text_xlsx(self, tmp_path: Path):
        """Test extracting text from a multi-sheet .xlsx file."""
        from openpyxl import Workbook

        xlsx_file = tmp_path / "test.xlsx"
        wb = Workbook()
        ws = wb.active
        ws.title = "Sheet1"
        ws["A1"] = "Name"
        ws["B1"] = "Age"
        ws["A2"] = "Alice"
        ws["B2"] = 30
        ws["A3"] = "Bob"
        ws["B3"] = 25

        # A second sheet, to check that every sheet gets its own header.
        ws2 = wb.create_sheet("Sheet2")
        ws2["A1"] = "Product"
        ws2["B1"] = "Price"
        ws2["A2"] = "Widget"
        ws2["B2"] = 9.99

        wb.save(xlsx_file)
        wb.close()

        result = extract_text(xlsx_file)
        assert result is not None
        assert "--- Sheet: Sheet1 ---" in result
        assert "--- Sheet: Sheet2 ---" in result
        assert "Alice" in result
        assert "Bob" in result
        assert "Widget" in result
        assert "9.99" in result

    def test_extract_text_xlsx_empty_sheet(self, tmp_path: Path):
        """Test extracting text from an .xlsx file whose only sheet is empty."""
        from openpyxl import Workbook

        xlsx_file = tmp_path / "empty.xlsx"
        wb = Workbook()
        # Replace the default sheet with a single, explicitly named empty one.
        wb.remove(wb.active)
        wb.create_sheet("EmptySheet")
        wb.save(xlsx_file)
        wb.close()

        result = extract_text(xlsx_file)
        # An empty sheet may yield either just its header line or nothing at all.
        assert result in ("--- Sheet: EmptySheet ---", "")

    def test_extract_text_docx(self, tmp_path: Path):
        """Test extracting text from a .docx file."""
        from docx import Document

        docx_file = tmp_path / "test.docx"
        doc = Document()
        doc.add_heading("Test Document", 0)
        doc.add_paragraph("This is paragraph one.")
        doc.add_paragraph("This is paragraph two.")
        doc.save(docx_file)

        result = extract_text(docx_file)
        assert result is not None
        assert "Test Document" in result
        assert "This is paragraph one." in result
        assert "This is paragraph two." in result

    def test_extract_text_docx_empty(self, tmp_path: Path):
        """Test extracting text from an empty .docx file."""
        from docx import Document

        docx_file = tmp_path / "empty.docx"
        doc = Document()
        doc.save(docx_file)

        result = extract_text(docx_file)
        assert result == ""

    def test_extract_text_pptx(self, tmp_path: Path):
        """Test extracting text from a .pptx file."""
        from pptx import Presentation

        pptx_file = tmp_path / "test.pptx"
        prs = Presentation()

        # Slide 1: fill whatever text-bearing placeholders the layout provides.
        slide1 = prs.slides.add_slide(prs.slide_layouts[0])
        for shape in slide1.shapes:
            if hasattr(shape, "text"):
                shape.text = "First Slide Title"

        # Slide 2: an explicitly added textbox, independent of layout defaults.
        slide2 = prs.slides.add_slide(prs.slide_layouts[5])
        left = top = width = height = 1000000
        textbox = slide2.shapes.add_textbox(left, top, width, height)
        text_frame = textbox.text_frame
        text_frame.text = "Bullet point content"

        prs.save(pptx_file)

        result = extract_text(pptx_file)
        assert result is not None
        assert "--- Slide 1 ---" in result
        assert "--- Slide 2 ---" in result
        # Placeholder text on slide 1 depends on the default template's layouts,
        # but the textbox we added ourselves must always be extracted.
        assert "Bullet point content" in result

    def test_extract_text_pptx_table(self, tmp_path: Path):
        """Table cells should be extracted, not silently dropped."""
        from pptx import Presentation
        from pptx.util import Inches

        pptx_file = tmp_path / "table.pptx"
        prs = Presentation()
        slide = prs.slides.add_slide(prs.slide_layouts[5])
        table = slide.shapes.add_table(
            2, 2, Inches(1), Inches(1), Inches(4), Inches(1)
        ).table
        table.cell(0, 0).text = "Header A"
        table.cell(0, 1).text = "Header B"
        table.cell(1, 0).text = "Alice"
        table.cell(1, 1).text = "Bob"
        prs.save(pptx_file)

        result = extract_text(pptx_file)
        assert result is not None
        assert "Header A" in result
        assert "Header B" in result
        assert "Alice" in result
        assert "Bob" in result

    def test_extract_text_pptx_grouped_shapes(self, tmp_path: Path):
        """Text inside grouped shapes must be extracted recursively."""
        from pptx import Presentation
        from pptx.util import Inches

        pptx_file = tmp_path / "grouped.pptx"
        prs = Presentation()
        slide = prs.slides.add_slide(prs.slide_layouts[5])
        group = slide.shapes.add_group_shape()
        inner = group.shapes.add_textbox(
            Inches(1), Inches(1), Inches(3), Inches(1)
        )
        inner.text_frame.text = "Inside group"
        prs.save(pptx_file)

        result = extract_text(pptx_file)
        assert result is not None
        assert "Inside group" in result

    def test_extract_text_pdf_not_found(self, tmp_path: Path):
        """Test that missing PDF files return an error string."""
        missing_pdf = tmp_path / "nonexistent.pdf"

        result = extract_text(missing_pdf)
        assert result is not None
        assert "[error: file not found:" in result

    def test_extract_text_image_files(self, tmp_path: Path):
        """Test that image files return placeholder text."""
        import struct
        import zlib

        def png_chunk(tag: bytes, data: bytes) -> bytes:
            # PNG chunk layout: 4-byte big-endian length, chunk type,
            # data, then a CRC-32 computed over type + data.
            crc = zlib.crc32(tag + data) & 0xFFFFFFFF
            return struct.pack(">I", len(data)) + tag + data + struct.pack(">I", crc)

        # Build a well-formed 1x1, 8-bit RGB PNG with correct chunk lengths and
        # CRCs, so the fixture stays valid even if the extractor inspects it.
        ihdr = struct.pack(">IIBBBBB", 1, 1, 8, 2, 0, 0, 0)
        # A single scanline: filter-type byte 0 followed by one RGB pixel.
        idat = zlib.compress(b"\x00\x00\x00\x00")
        png_data = (
            b"\x89PNG\r\n\x1a\n"
            + png_chunk(b"IHDR", ihdr)
            + png_chunk(b"IDAT", idat)
            + png_chunk(b"IEND", b"")
        )
        png_file = tmp_path / "test.png"
        png_file.write_bytes(png_data)

        result = extract_text(png_file)
        assert result is not None
        assert "[image:" in result
        assert "test.png" in result
|
|
|
|
|
|
class TestIsTextExtension:
    """Unit tests for the private ``_is_text_extension`` predicate."""

    def test_text_extensions_return_true(self):
        """Every extension that should be read as plain text is recognised."""
        text_like = (
            ".txt",
            ".md",
            ".csv",
            ".json",
            ".yaml",
            ".yml",
            ".xml",
            ".html",
            ".htm",
        )
        for ext in text_like:
            assert _is_text_extension(ext) is True, ext

    def test_non_text_extensions_return_false(self):
        """Binary document, image, and unknown extensions are not treated as text."""
        not_text = (".pdf", ".docx", ".xlsx", ".pptx", ".png", ".xyz")
        for ext in not_text:
            assert _is_text_extension(ext) is False, ext

    def test_case_sensitivity(self):
        """The helper matches lowercase extensions only.

        Callers such as ``extract_text`` are expected to lowercase the suffix
        before consulting this helper.
        """
        expectations = (
            (".txt", True),
            (".TXT", False),
            (".pdf", False),
        )
        for ext, expected in expectations:
            assert _is_text_extension(ext) is expected, ext
|