Skip to content

Commit b3b43fe

Browse files
committed
feat: allow custom image type categories in MistralOCRConfig
1 parent 724f066 commit b3b43fe

4 files changed

Lines changed: 47 additions & 41 deletions

File tree

src/raglite/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""RAGLite."""
22

3-
from raglite._config import ImageType, MistralOCRConfig, RAGLiteConfig
3+
from raglite._config import MistralOCRConfig, RAGLiteConfig
44
from raglite._database import Document
55
from raglite._delete import delete_documents, delete_documents_by_metadata
66
from raglite._eval import answer_evals, evaluate, insert_evals
@@ -25,7 +25,6 @@
2525
"RAGLiteConfig",
2626
"MistralOCRConfig",
2727
"MistralOCRError",
28-
"ImageType",
2928
# Insert
3029
"Document",
3130
"insert_documents",

src/raglite/_config.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import contextlib
44
import os
55
from dataclasses import dataclass, field
6-
from enum import Enum
76
from io import StringIO
87
from pathlib import Path
98
from typing import Literal
@@ -24,18 +23,9 @@
2423
cache_path = Path(user_data_dir("raglite", ensure_exists=True))
2524

2625

27-
class ImageType(str, Enum):
28-
"""Type of image detected by OCR."""
29-
30-
GRAPH = "graph"
31-
CHART = "chart"
32-
DIAGRAM = "diagram"
33-
TABLE = "table"
34-
PHOTO = "photo"
35-
SCREENSHOT = "screenshot"
36-
LOGO = "logo"
37-
ICON = "icon"
38-
OTHER = "other"
26+
DEFAULT_IMAGE_TYPES = frozenset(
27+
{"graph", "chart", "diagram", "table", "photo", "screenshot", "logo", "icon", "other"}
28+
)
3929

4030

4131
@dataclass(frozen=True)
@@ -46,8 +36,10 @@ class MistralOCRConfig:
4636
api_key: str | None = None
4737
# Whether to use vision to describe images in documents.
4838
include_image_descriptions: bool = True
49-
# Image types to exclude from processing (e.g., {ImageType.LOGO, ImageType.ICON}).
50-
exclude_image_types: frozenset[ImageType] = frozenset()
39+
# Image types that Mistral classifies each image into.
40+
image_types: frozenset[str] = DEFAULT_IMAGE_TYPES
41+
# Image types to exclude from the output (e.g., {"logo", "icon"}).
42+
exclude_image_types: frozenset[str] = frozenset()
5143
model: str = "mistral-ocr-latest"
5244

5345

src/raglite/_mistral_ocr.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
import logging
55
import os
66
import re
7+
from enum import Enum
78
from pathlib import Path
89
from typing import Any
910

1011
from pydantic import BaseModel, Field
1112

12-
from raglite._config import ImageType, MistralOCRConfig
13+
from raglite._config import MistralOCRConfig
1314

1415
logger = logging.getLogger(__name__)
1516

@@ -32,24 +33,28 @@ class MistralOCRError(Exception):
3233
"""Error during MistralOCR processing."""
3334

3435

35-
_IMAGE_TYPE_VALUES = ", ".join(t.value for t in ImageType)
36+
def _build_image_annotation_model(image_types: frozenset[str]) -> type[BaseModel]:
37+
"""Build an ImageAnnotation Pydantic model with the given image types."""
38+
image_type_enum = Enum("ImageType", {t.upper(): t for t in sorted(image_types)}, type=str) # type: ignore[misc]
39+
image_type_values = ", ".join(sorted(image_types))
3640

41+
class ImageAnnotation(BaseModel):
42+
"""Schema for vision-based image annotation."""
3743

38-
class ImageAnnotation(BaseModel):
39-
"""Schema for vision-based image annotation."""
44+
image_type: image_type_enum = Field( # type: ignore[valid-type]
45+
...,
46+
description=f"The type of the image. Must be one of: {image_type_values}.",
47+
)
48+
description: str = Field(
49+
...,
50+
description=(
51+
"A concise description of the image content. For diagrams and charts, "
52+
"describe what is being illustrated. For tables, summarize the data. "
53+
"For photos, describe the subject matter."
54+
),
55+
)
4056

41-
image_type: ImageType = Field(
42-
...,
43-
description=f"The type of the image. Must be one of: {_IMAGE_TYPE_VALUES}.",
44-
)
45-
description: str = Field(
46-
...,
47-
description=(
48-
"A concise description of the image content. For diagrams and charts, "
49-
"describe what is being illustrated. For tables, summarize the data. "
50-
"For photos, describe the subject matter."
51-
),
52-
)
57+
return ImageAnnotation
5358

5459

5560
def _get_api_key(processor_config: MistralOCRConfig) -> str:
@@ -101,8 +106,9 @@ def _encode_document_base64(doc_path: Path) -> tuple[str, str]:
101106
def _process_ocr_response(
102107
ocr_response: Any,
103108
*,
109+
annotation_model: type[BaseModel],
104110
include_image_descriptions: bool = True,
105-
exclude_image_types: frozenset[ImageType] | None = None,
111+
exclude_image_types: frozenset[str] | None = None,
106112
) -> str:
107113
"""Convert MistralOCR response to markdown string.
108114
@@ -113,10 +119,12 @@ def _process_ocr_response(
113119
----------
114120
ocr_response
115121
Response from Mistral OCR API.
122+
annotation_model
123+
The Pydantic model used to parse image annotations.
116124
include_image_descriptions
117125
Whether to replace image placeholders with annotations.
118126
exclude_image_types
119-
Set of ImageType values to exclude from output.
127+
Set of image type strings to exclude from output.
120128
121129
Returns
122130
-------
@@ -137,12 +145,13 @@ def _process_ocr_response(
137145
placeholder_pattern = rf"!\[[^\]]*\]\({re.escape(img.id)}\)"
138146
# Parse annotation to check image type for filtering.
139147
try:
140-
parsed = ImageAnnotation.model_validate_json(annotation)
141-
if parsed.image_type in exclude_image_types:
148+
parsed: Any = annotation_model.model_validate_json(annotation)
149+
image_type = parsed.image_type.value
150+
if image_type in exclude_image_types:
142151
# Remove the image placeholder entirely.
143152
page_md = re.sub(placeholder_pattern, "", page_md)
144153
continue
145-
replacement = f"[Image ({parsed.image_type.value}): {parsed.description}]"
154+
replacement = f"[Image ({image_type}): {parsed.description}]"
146155
except (ValueError, TypeError):
147156
# If parsing fails, use raw annotation.
148157
replacement = f"[Image: {annotation}]"
@@ -201,13 +210,15 @@ def mistral_ocr_to_markdown(doc_path: Path, *, processor_config: MistralOCRConfi
201210
"include_image_base64": False, # We don't need base64, just annotations.
202211
}
203212

213+
annotation_model = _build_image_annotation_model(processor_config.image_types)
214+
204215
try:
205216
client = _get_mistral_client(processor_config)
206217
# Add bbox annotation format if image descriptions are enabled.
207218
if processor_config.include_image_descriptions:
208219
response_format_from_pydantic_model = _get_response_format_converter()
209220
ocr_params["bbox_annotation_format"] = response_format_from_pydantic_model(
210-
ImageAnnotation
221+
annotation_model
211222
)
212223
ocr_response = client.ocr.process(**ocr_params)
213224
except (ImportError, ValueError):
@@ -219,6 +230,7 @@ def mistral_ocr_to_markdown(doc_path: Path, *, processor_config: MistralOCRConfi
219230
# Process response and replace image placeholders with annotations.
220231
return _process_ocr_response(
221232
ocr_response,
233+
annotation_model=annotation_model,
222234
include_image_descriptions=processor_config.include_image_descriptions,
223235
exclude_image_types=processor_config.exclude_image_types,
224236
)

tests/test_mistral_ocr.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66

77
import pytest
88

9-
from raglite import ImageType, MistralOCRConfig
9+
from raglite import MistralOCRConfig
1010
from raglite._mistral_ocr import (
11+
_build_image_annotation_model,
1112
_process_ocr_response,
1213
mistral_ocr_to_markdown,
1314
)
@@ -42,10 +43,12 @@ def test_process_ocr_response() -> None:
4243
("![](img-r.jpeg)", [("img-r.jpeg", "raw fallback text")]), # page 2
4344
]
4445
)
46+
annotation_model = _build_image_annotation_model(frozenset({"diagram", "logo"}))
4547
result = _process_ocr_response(
4648
response,
49+
annotation_model=annotation_model,
4750
include_image_descriptions=True,
48-
exclude_image_types=frozenset({ImageType.LOGO}),
51+
exclude_image_types=frozenset({"logo"}),
4952
)
5053
assert "[Image (diagram): A flowchart]" in result
5154
assert "Company logo" not in result

0 commit comments

Comments
 (0)