44import logging
55import os
66import re
7+ from enum import Enum
78from pathlib import Path
89from typing import Any
910
1011from pydantic import BaseModel , Field
1112
12- from raglite ._config import ImageType , MistralOCRConfig
13+ from raglite ._config import MistralOCRConfig
1314
1415logger = logging .getLogger (__name__ )
1516
@@ -32,24 +33,28 @@ class MistralOCRError(Exception):
3233 """Error during MistralOCR processing."""
3334
3435
35- _IMAGE_TYPE_VALUES = ", " .join (t .value for t in ImageType )
36+ def _build_image_annotation_model (image_types : frozenset [str ]) -> type [BaseModel ]:
37+ """Build an ImageAnnotation Pydantic model with the given image types."""
38+ image_type_enum = Enum ("ImageType" , {t .upper (): t for t in sorted (image_types )}, type = str ) # type: ignore[misc]
39+ image_type_values = ", " .join (sorted (image_types ))
3640
41+ class ImageAnnotation (BaseModel ):
42+ """Schema for vision-based image annotation."""
3743
38- class ImageAnnotation (BaseModel ):
39- """Schema for vision-based image annotation."""
44+ image_type : image_type_enum = Field ( # type: ignore[valid-type]
45+ ...,
46+ description = f"The type of the image. Must be one of: { image_type_values } ." ,
47+ )
48+ description : str = Field (
49+ ...,
50+ description = (
51+ "A concise description of the image content. For diagrams and charts, "
52+ "describe what is being illustrated. For tables, summarize the data. "
53+ "For photos, describe the subject matter."
54+ ),
55+ )
4056
41- image_type : ImageType = Field (
42- ...,
43- description = f"The type of the image. Must be one of: { _IMAGE_TYPE_VALUES } ." ,
44- )
45- description : str = Field (
46- ...,
47- description = (
48- "A concise description of the image content. For diagrams and charts, "
49- "describe what is being illustrated. For tables, summarize the data. "
50- "For photos, describe the subject matter."
51- ),
52- )
57+ return ImageAnnotation
5358
5459
5560def _get_api_key (processor_config : MistralOCRConfig ) -> str :
@@ -101,8 +106,9 @@ def _encode_document_base64(doc_path: Path) -> tuple[str, str]:
101106def _process_ocr_response (
102107 ocr_response : Any ,
103108 * ,
109+ annotation_model : type [BaseModel ],
104110 include_image_descriptions : bool = True ,
105- exclude_image_types : frozenset [ImageType ] | None = None ,
111+ exclude_image_types : frozenset [str ] | None = None ,
106112) -> str :
107113 """Convert MistralOCR response to markdown string.
108114
@@ -113,10 +119,12 @@ def _process_ocr_response(
113119 ----------
114120 ocr_response
115121 Response from Mistral OCR API.
122+ annotation_model
123+ The Pydantic model used to parse image annotations.
116124 include_image_descriptions
117125 Whether to replace image placeholders with annotations.
118126 exclude_image_types
119- Set of ImageType values to exclude from output.
127+ Set of image type strings to exclude from output.
120128
121129 Returns
122130 -------
@@ -137,12 +145,13 @@ def _process_ocr_response(
137145 placeholder_pattern = rf"!\[[^\]]*\]\({ re .escape (img .id )} \)"
138146 # Parse annotation to check image type for filtering.
139147 try :
140- parsed = ImageAnnotation .model_validate_json (annotation )
141- if parsed .image_type in exclude_image_types :
148+ parsed : Any = annotation_model .model_validate_json (annotation )
149+ image_type = parsed .image_type .value
150+ if image_type in exclude_image_types :
142151 # Remove the image placeholder entirely.
143152 page_md = re .sub (placeholder_pattern , "" , page_md )
144153 continue
145- replacement = f"[Image ({ parsed . image_type . value } ): { parsed .description } ]"
154+ replacement = f"[Image ({ image_type } ): { parsed .description } ]"
146155 except (ValueError , TypeError ):
147156 # If parsing fails, use raw annotation.
148157 replacement = f"[Image: { annotation } ]"
@@ -201,13 +210,15 @@ def mistral_ocr_to_markdown(doc_path: Path, *, processor_config: MistralOCRConfi
201210 "include_image_base64" : False , # We don't need base64, just annotations.
202211 }
203212
213+ annotation_model = _build_image_annotation_model (processor_config .image_types )
214+
204215 try :
205216 client = _get_mistral_client (processor_config )
206217 # Add bbox annotation format if image descriptions are enabled.
207218 if processor_config .include_image_descriptions :
208219 response_format_from_pydantic_model = _get_response_format_converter ()
209220 ocr_params ["bbox_annotation_format" ] = response_format_from_pydantic_model (
210- ImageAnnotation
221+ annotation_model
211222 )
212223 ocr_response = client .ocr .process (** ocr_params )
213224 except (ImportError , ValueError ):
@@ -219,6 +230,7 @@ def mistral_ocr_to_markdown(doc_path: Path, *, processor_config: MistralOCRConfi
219230 # Process response and replace image placeholders with annotations.
220231 return _process_ocr_response (
221232 ocr_response ,
233+ annotation_model = annotation_model ,
222234 include_image_descriptions = processor_config .include_image_descriptions ,
223235 exclude_image_types = processor_config .exclude_image_types ,
224236 )
0 commit comments