77import boto3
88import pymupdf
99from botocore .exceptions import ClientError
10+ from pydantic import BaseModel , Field
1011
1112from src .classifiers .classifier_types import Classifier , ClassifierTypes
1213from src .classifiers .utils import clean_label , map_string_to_page_class , read_image_bytes
1819logger = logging .getLogger (__name__ )
1920
2021
class PixtralImageSource(BaseModel):
    """Wrapper around the raw bytes of a single image."""

    # Required field; populated and serialised under the alias "bytes".
    bytes_: bytes = Field(..., alias="bytes")
26+
27+
class PixtralImage(BaseModel):
    """Image content block: a format name plus the image's byte source."""

    # Required field; populated and serialised under the alias "format".
    format_: str = Field(..., alias="format")
    # Nested model holding the image bytes.
    source: PixtralImageSource
33+
34+
class PixtralMessage(BaseModel):
    """One content block of a Pixtral conversation.

    Both fields are optional and default to ``None``; the model itself does
    not enforce that exactly one of them is set.
    """

    # Text payload of the block, if any.
    text: str | None = None
    # Image payload of the block, if any.
    image: PixtralImage | None = None
40+
41+
class PixtralMessageStack(BaseModel):
    """One conversation turn: the speaker's role and its ordered content blocks."""

    # Speaker of the turn, e.g. "user".
    role: str
    # Content blocks in the order they are sent or were received.
    content: list[PixtralMessage]
47+
48+
class PixtralResponseOutput(BaseModel):
    """The ``output`` member of a response; holds the returned message."""

    # The message turn carried by the response.
    message: PixtralMessageStack
53+
54+
class PixtralResponse(BaseModel):
    """Root of a model response; only the ``output`` member is modelled."""

    # Wrapper around the returned message.
    output: PixtralResponseOutput
59+
60+
2161class RateLimiter :
2262 """Simple token bucket QPS limiter."""
2363
2464 def __init__ (self , qps : float ):
65+ """Initialise the rate limiter with a target queries-per-second rate.
66+
67+ Args:
68+ qps (float): Maximum number of requests allowed per second.
69+ """
2570 self .qps = max (0.1 , qps )
2671 self .lock = threading .Lock ()
2772 self .tokens = 0.0
2873 self .last = time .monotonic ()
2974
3075 def acquire (self ):
76+ """Block until a token is available, then consume it."""
3177 while True :
3278 with self .lock :
3379 now = time .monotonic ()
@@ -53,23 +99,107 @@ def is_throttle_error(e) -> bool:
5399 return False
54100
55101
56- class PixtralClassifier (Classifier ):
57- """Page Classifier using Pixtral Large."""
class PixtralConnector:
    """Low-level client for the Pixtral model.

    Owns the Bedrock runtime client and sends conversations through a QPS
    limiter, retrying throttled calls with capped exponential back-off and
    full jitter.
    """

    # Retries granted to failures that are not botocore ``ClientError``s.
    _MAX_TRANSIENT_RETRIES = 2

    def __init__(
        self,
        config: dict,
        aws_config: dict,
    ):
        """Initialise client and rate-limiting settings.

        Args:
            config (dict): Pixtral configuration dict. ``max_document_size_mb``
                and ``slack_size_mb`` are required; ``qps``, ``max_retries``,
                ``backoff_base``, ``backoff_cap``, ``max_tokens`` and
                ``temperature`` are optional and fall back to defaults.
            aws_config (dict): AWS settings dict providing ``region`` and
                ``model_id``.
        """
        self.config = config
        self.client = boto3.client("bedrock-runtime", region_name=aws_config["region"])
        self.model_id = aws_config["model_id"]
        self._stats = {"throttles": 0, "retries": 0}
        self.qps = config.get("qps", 2.0)
        self.max_retries = config.get("max_retries", 6)
        self.backoff_base = config.get("backoff_base", 0.4)
        self.backoff_cap = config.get("backoff_cap", 8.0)
        self._rl = RateLimiter(self.qps)
        # Page-image size budget: the configured maximum minus a safety margin.
        self.max_doc_size = self.config["max_document_size_mb"] - self.config["slack_size_mb"]

    def _backoff_delay(self, attempt: int) -> float:
        """Return the sleep time in seconds before throttle retry ``attempt`` (0-based).

        Full jitter: a uniform draw between zero and the capped exponential
        ceiling, so the result never exceeds ``backoff_cap``.

        Args:
            attempt (int): Number of throttle retries already performed.

        Returns:
            float: Delay in the range ``[0, min(backoff_cap, backoff_base * 2**attempt)]``.
        """
        ceiling = min(self.backoff_cap, self.backoff_base * (2**attempt))
        return random.uniform(0.0, ceiling)

    def _send_conversation(self, message: PixtralMessageStack, system: PixtralMessage) -> PixtralResponse:
        """Send a single-turn conversation to the Pixtral model.

        Throttled calls are retried up to ``max_retries`` times; other
        non-``ClientError`` failures are retried up to two times. The two
        budgets are tracked independently.

        Args:
            message (PixtralMessageStack): The user message stack to send.
            system (PixtralMessage): The system prompt message.

        Returns:
            PixtralResponse: The validated model response.

        Raises:
            ClientError: If the call fails with a non-throttling error, or is
                still throttled after ``max_retries`` retries.
            Exception: Any other failure of the call once its retries are used
                up. Errors raised while validating the response are not
                retried and propagate immediately.
        """
        # The request payload does not change between retries; build it once.
        messages = [message.model_dump(by_alias=True, exclude_none=True)]
        system_blocks = [system.model_dump(by_alias=True, exclude_none=True)]
        inference_config = {
            "maxTokens": self.config.get("max_tokens", 5),
            "temperature": self.config.get("temperature", 0.2),
        }

        throttle_retries = 0
        transient_retries = 0
        while True:
            self._rl.acquire()  # ensure we don't exceed QPS
            try:
                answer = self.client.converse(
                    modelId=self.model_id,
                    messages=messages,
                    system=system_blocks,
                    inferenceConfig=inference_config,
                )
            except ClientError as e:
                if not is_throttle_error(e) or throttle_retries >= self.max_retries:
                    raise  # not retryable or out of retries
                delay = self._backoff_delay(throttle_retries)
                logger.warning(
                    "Bedrock throttled (attempt %d/%d); sleep %.2fs",
                    throttle_retries + 1,
                    self.max_retries,
                    delay,
                )
                time.sleep(delay)
                throttle_retries += 1
                self._stats["retries"] += 1
                if "Throttl" in str(e):
                    self._stats["throttles"] += 1
                continue
            except Exception:
                # Non-ClientError (e.g. connection problems); retry a couple of times.
                if transient_retries >= self._MAX_TRANSIENT_RETRIES:
                    raise
                transient_retries += 1
                logger.warning(
                    "Bedrock call failed (retry %d/%d)",
                    transient_retries,
                    self._MAX_TRANSIENT_RETRIES,
                    exc_info=True,
                )
                time.sleep(0.5 * transient_retries)
                self._stats["retries"] += 1
                continue

            # Validate outside the try block so a schema mismatch does not
            # trigger another request.
            return PixtralResponse.model_validate(answer)
177+
178+
179+ class PixtralClassifier (PixtralConnector , Classifier ):
180+ """Page classifier that uses the Pixtral vision model."""
71181
72- self .system_content = [{"text" : self .prompts_dict ["system_prompt" ]}]
182+ def __init__ (
183+ self ,
184+ config : dict ,
185+ aws_config : dict ,
186+ fallback_classifier : Callable = None ,
187+ ):
188+ """Initialise the classifier, loading prompts and example images.
189+
190+ Args:
191+ config (dict): Pixtral configuration dict.
192+ aws_config (dict): AWS settings dict.
193+ fallback_classifier (Callable): Optional classifier to use when Pixtral
194+ returns an unrecognised label or errors out.
195+ """
196+ # Create connection to remote model
197+ PixtralConnector .__init__ (self , config = config , aws_config = aws_config )
198+
199+ self .type = ClassifierTypes .PIXTRAL
200+ self .prompts_dict = read_params (config ["prompt_path" ])[config ["prompt_version" ]]
201+ self .fallback_classifier = fallback_classifier
202+ self .system_content = PixtralMessage (text = self .prompts_dict ["system_prompt" ])
73203 self .examples_bytes = {
74204 "borehole" : read_image_bytes (config ["borehole_img_path" ]),
75205 "text" : read_image_bytes (config ["text_img_path" ]),
@@ -79,12 +209,6 @@ def __init__(
79209 "diagram" : read_image_bytes (config ["diagram_img_path" ]),
80210 "table" : read_image_bytes (config ["table_img_path" ]),
81211 }
82- self ._stats = {"throttles" : 0 , "retries" : 0 }
83- self .qps = config .get ("qps" , 2.0 )
84- self .max_retries = config .get ("max_retries" , 6 )
85- self .backoff_base = config .get ("backoff_base" , 0.4 )
86- self .backoff_cap = config .get ("backoff_cap" , 8.0 )
87- self ._rl = RateLimiter (self .qps )
88212
89213 def determine_class (
90214 self , page : pymupdf .Page , page_number : int , context_builder : Callable [[], PageContext ] = None , ** kwargs
@@ -102,14 +226,12 @@ def determine_class(
102226 Returns:
103227 PageClasses: The predicted page class.
104228 """
105- max_doc_size = self .config ["max_document_size_mb" ] - self .config ["slack_size_mb" ]
106- image_bytes = get_page_image_bytes (page , page_number , max_mb = max_doc_size )
107-
108- conversation = self ._build_conversation (image_bytes = image_bytes )
229+ image_bytes = get_page_image_bytes (page , max_mb = self .max_doc_size )
230+ message = self ._build_conversation (image_bytes = image_bytes )
109231
110232 try :
111- response = self ._send_conversation (conversation )
112- raw_label = response [ " output" ][ " message" ][ " content" ] [0 ][ " text" ]
233+ response = self ._send_conversation (message = message , system = self . system_content )
234+ raw_label = response . output . message . content [0 ]. text
113235
114236 label = clean_label (raw_label )
115237 category = map_string_to_page_class (label )
@@ -138,52 +260,90 @@ def determine_class(
138260 )
139261 return PageClasses .UNKNOWN
140262
141- def _build_conversation (self , image_bytes : bytes ) -> list [dict ]:
142- content = [
143- {"image" : {"format" : "jpeg" , "source" : {"bytes" : self .examples_bytes [text .strip ("@" )]}}}
144- if text .startswith ("@" ) # @category encodes the image of the category and adds it to the content
145- else {"text" : text }
def _build_conversation(self, image_bytes: bytes) -> PixtralMessageStack:
    """Assemble the user turn: few-shot examples, the user prompt, then the page image.

    Entries of ``prompts_dict["examples_prompt"]`` that start with "@" are
    replaced by the stored example image of that category (looked up in
    ``self.examples_bytes`` after stripping "@"); every other entry is sent
    as a text block.

    Args:
        image_bytes (bytes): Encoded bytes of the page to classify.

    Returns:
        PixtralMessageStack: A user turn ready to send.
    """

    def as_image_block(payload: bytes) -> PixtralMessage:
        # All images are declared with the "jpeg" format.
        return PixtralMessage(
            image=PixtralImage(
                format="jpeg",
                source=PixtralImageSource(bytes=payload),
            )
        )

    # Few-shot examples for the model.
    blocks = []
    for entry in self.prompts_dict.get("examples_prompt", []):
        if entry.startswith("@"):
            blocks.append(as_image_block(self.examples_bytes[entry.strip("@")]))
        else:
            blocks.append(PixtralMessage(text=entry))

    # The actual request: user prompt followed by the page to classify.
    blocks.append(PixtralMessage(text=self.prompts_dict["user_prompt"]))
    blocks.append(as_image_block(image_bytes))

    return PixtralMessageStack(role="user", content=blocks)
295+
296+
class PixtralFeatureExtraction(PixtralConnector):
    """Uses the Pixtral vision model to extract features from PDF pages."""

    def __init__(self, config: dict, aws_config: dict, system_prompt: str):
        """Initialise the extractor with a custom system prompt.

        Args:
            config (dict): Pixtral configuration dict.
            aws_config (dict): AWS settings dict.
            system_prompt (str): Instruction text sent as the system message for
                every extraction request.
        """
        # Create connection to remote model
        PixtralConnector.__init__(self, config=config, aws_config=aws_config)
        self.system_prompt = PixtralMessage(text=system_prompt)

    def _build_conversation(self, image_bytes: bytes) -> PixtralMessageStack:
        """Build a minimal user message containing only the target page image.

        Args:
            image_bytes (bytes): Encoded bytes of the page to process.

        Returns:
            PixtralMessageStack: A 'user' turn with a single image content block.
        """
        # No examples and no user text: the page image is the whole turn.
        page_block = PixtralMessage(
            image=PixtralImage(
                format="jpeg",
                source=PixtralImageSource(bytes=image_bytes),
            )
        )
        return PixtralMessageStack(role="user", content=[page_block])

    def find(self, page: pymupdf.Page) -> str | None:
        """Extract a feature from a single PDF page using the Pixtral model.

        Args:
            page (pymupdf.Page): The PyMuPDF page object to process.

        Returns:
            str | None: The text of the first content block of the model's
                reply (e.g. an extracted title), or ``None`` if that block
                carries no text.
        """
        # Render the page within the size budget and wrap it in a user turn.
        image_bytes = get_page_image_bytes(page, max_mb=self.max_doc_size)
        content_user = self._build_conversation(image_bytes=image_bytes)

        response = self._send_conversation(message=content_user, system=self.system_prompt)
        return response.output.message.content[0].text
0 commit comments