Skip to content

Commit 7304b55

Browse files
base script with connection to pixtral to extract pages
1 parent 467b8b5 commit 7304b55

File tree

7 files changed

+428
-66
lines changed

7 files changed

+428
-66
lines changed

config/pixtral_config.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ prompt_version: v1.2 # current best is v1.2
1414
max_document_size_mb: 4.5
1515
slack_size_mb: 0.2
1616
# conversation generation
17-
max_tokens: 5
17+
max_tokens: 100
1818
temperature: 0.2
1919
# throttling
2020
qps: 3.0
2121
max_retries: 6
2222
backoff_base: 0.4
23-
backoff_cap: 8.0
23+
backoff_cap: 8.0
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
v0:
2+
system_prompt: |
3+
You are an expert in document layout analysis and content structure recognition.
4+
Your task is to identify and extract the main title of the provided document.
5+
6+
# Guidelines for identifying the title:
7+
8+
- The title is typically the most prominent text element — largest font, bold, or positioned at the center or top of the document
9+
- It represents the primary subject or name of the document, not an author name, or company name
10+
- If multiple candidate titles exist, select the one that best represents the document's overall topic
11+
12+
# Output instructions:
13+
14+
Return the title text only — no explanations, punctuation wrappers, quotes, or extra formatting
15+
Preserve the original capitalization of the title

src/classifiers/pixtral_classifier.py

Lines changed: 222 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import boto3
88
import pymupdf
99
from botocore.exceptions import ClientError
10+
from pydantic import BaseModel, Field
1011

1112
from src.classifiers.classifier_types import Classifier, ClassifierTypes
1213
from src.classifiers.utils import clean_label, map_string_to_page_class, read_image_bytes
@@ -18,16 +19,61 @@
1819
logger = logging.getLogger(__name__)
1920

2021

22+
class PixtralImageSource(BaseModel):
23+
"""Raw bytes payload for an image."""
24+
25+
bytes_: bytes = Field(alias="bytes")
26+
27+
28+
class PixtralImage(BaseModel):
29+
"""Image content block containing its format and raw bytes source."""
30+
31+
format_: str = Field(alias="format")
32+
source: PixtralImageSource
33+
34+
35+
class PixtralMessage(BaseModel):
36+
"""A single content block in a Pixtral conversation, either text or image."""
37+
38+
text: str | None = None
39+
image: PixtralImage | None = None
40+
41+
42+
class PixtralMessageStack(BaseModel):
43+
"""A full conversation turn with a role (e.g. 'user') and a list of content blocks."""
44+
45+
role: str
46+
content: list[PixtralMessage]
47+
48+
49+
class PixtralResponseOutput(BaseModel):
50+
"""The output field of response, wrapping the assistant message."""
51+
52+
message: PixtralMessageStack
53+
54+
55+
class PixtralResponse(BaseModel):
56+
"""Top-level response, containing the model output."""
57+
58+
output: PixtralResponseOutput
59+
60+
2161
class RateLimiter:
2262
"""Simple token bucket QPS limiter."""
2363

2464
def __init__(self, qps: float):
65+
"""Initialise the rate limiter with a target queries-per-second rate.
66+
67+
Args:
68+
qps (float): Maximum number of requests allowed per second.
69+
"""
2570
self.qps = max(0.1, qps)
2671
self.lock = threading.Lock()
2772
self.tokens = 0.0
2873
self.last = time.monotonic()
2974

3075
def acquire(self):
76+
"""Block until a token is available, then consume it."""
3177
while True:
3278
with self.lock:
3379
now = time.monotonic()
@@ -53,23 +99,107 @@ def is_throttle_error(e) -> bool:
5399
return False
54100

55101

56-
class PixtralClassifier(Classifier):
57-
"""Page Classifier using Pixtral Large."""
102+
class PixtralConnector:
103+
"""Low-level client for the Pixtral model.
104+
105+
Handles authentication, rate limiting, and retries with exponential
106+
back-off and randomised jitter when the API throttles requests.
107+
"""
58108

59109
def __init__(
60110
self,
61111
config: dict,
62112
aws_config: dict,
63-
fallback_classifier=None,
64113
):
65-
self.type = ClassifierTypes.PIXTRAL
114+
"""Initialise client and rate-limiting settings.
115+
116+
Args:
117+
config (dict): Pixtral configuration dict.
118+
aws_config (dict): AWS settings dict.
119+
"""
66120
self.config = config
67-
self.prompts_dict = read_params(config["prompt_path"])[config["prompt_version"]]
68121
self.client = boto3.client("bedrock-runtime", region_name=aws_config["region"])
69-
self.fallback_classifier = fallback_classifier
70122
self.model_id = aws_config["model_id"]
123+
self._stats = {"throttles": 0, "retries": 0}
124+
self.qps = config.get("qps", 2.0)
125+
self.max_retries = config.get("max_retries", 6)
126+
self.backoff_base = config.get("backoff_base", 0.4)
127+
self.backoff_cap = config.get("backoff_cap", 8.0)
128+
self._rl = RateLimiter(self.qps)
129+
self.max_doc_size = self.config["max_document_size_mb"] - self.config["slack_size_mb"]
130+
131+
def _send_conversation(self, message: PixtralMessageStack, system: PixtralMessage) -> PixtralResponse:
132+
"""Send a single-turn conversation to the Pixtral model.
133+
134+
Args:
135+
message (PixtralMessageStack): The user message stack to send.
136+
system (PixtralMessage): The system prompt message.
137+
138+
Returns:
139+
PixtralResponse: The validated model response.
140+
"""
141+
attempt = 0
142+
while True:
143+
self._rl.acquire() # ensure we dont exceed QPS
144+
try:
145+
answer = self.client.converse(
146+
modelId=self.model_id,
147+
messages=[message.model_dump(by_alias=True, exclude_none=True)],
148+
system=[system.model_dump(by_alias=True, exclude_none=True)],
149+
inferenceConfig={
150+
"maxTokens": self.config.get("max_tokens", 5),
151+
"temperature": self.config.get("temperature", 0.2),
152+
},
153+
)
154+
return PixtralResponse.model_validate(answer)
155+
except ClientError as e:
156+
# Retry on throttling
157+
if is_throttle_error(e) and attempt < self.max_retries:
158+
delay = min(self.backoff_cap, self.backoff_base * (2**attempt))
159+
# jitter: scale the delay by a random factor between 0.5 and 1.5 (the result can exceed backoff_cap)
160+
delay *= random.uniform(0.5, 1.5)
161+
logger.warning(f"Bedrock throttled (attempt {attempt + 1}/{self.max_retries}); sleep {delay:.2f}s")
162+
time.sleep(delay)
163+
attempt += 1
164+
165+
self._stats["retries"] += 1
166+
if "Throttl" in str(e):
167+
self._stats["throttles"] += 1
168+
continue
169+
raise # not retryable or out of retries
170+
except Exception:
171+
# Non-ClientError (this also covers a response that fails PixtralResponse validation); retry a couple of times
172+
if attempt < 2:
173+
time.sleep(0.5 * (attempt + 1))
174+
attempt += 1
175+
continue
176+
raise
177+
178+
179+
class PixtralClassifier(PixtralConnector, Classifier):
180+
"""Page classifier that uses the Pixtral vision model."""
71181

72-
self.system_content = [{"text": self.prompts_dict["system_prompt"]}]
182+
def __init__(
183+
self,
184+
config: dict,
185+
aws_config: dict,
186+
fallback_classifier: Callable = None,
187+
):
188+
"""Initialise the classifier, loading prompts and example images.
189+
190+
Args:
191+
config (dict): Pixtral configuration dict.
192+
aws_config (dict): AWS settings dict.
193+
fallback_classifier (Callable): Optional classifier to use when Pixtral
194+
returns an unrecognised label or errors out.
195+
"""
196+
# Create connection to remote model
197+
PixtralConnector.__init__(self, config=config, aws_config=aws_config)
198+
199+
self.type = ClassifierTypes.PIXTRAL
200+
self.prompts_dict = read_params(config["prompt_path"])[config["prompt_version"]]
201+
self.fallback_classifier = fallback_classifier
202+
self.system_content = PixtralMessage(text=self.prompts_dict["system_prompt"])
73203
self.examples_bytes = {
74204
"borehole": read_image_bytes(config["borehole_img_path"]),
75205
"text": read_image_bytes(config["text_img_path"]),
@@ -79,12 +209,6 @@ def __init__(
79209
"diagram": read_image_bytes(config["diagram_img_path"]),
80210
"table": read_image_bytes(config["table_img_path"]),
81211
}
82-
self._stats = {"throttles": 0, "retries": 0}
83-
self.qps = config.get("qps", 2.0)
84-
self.max_retries = config.get("max_retries", 6)
85-
self.backoff_base = config.get("backoff_base", 0.4)
86-
self.backoff_cap = config.get("backoff_cap", 8.0)
87-
self._rl = RateLimiter(self.qps)
88212

89213
def determine_class(
90214
self, page: pymupdf.Page, page_number: int, context_builder: Callable[[], PageContext] = None, **kwargs
@@ -102,14 +226,12 @@ def determine_class(
102226
Returns:
103227
PageClasses: The predicted page class.
104228
"""
105-
max_doc_size = self.config["max_document_size_mb"] - self.config["slack_size_mb"]
106-
image_bytes = get_page_image_bytes(page, page_number, max_mb=max_doc_size)
107-
108-
conversation = self._build_conversation(image_bytes=image_bytes)
229+
image_bytes = get_page_image_bytes(page, max_mb=self.max_doc_size)
230+
message = self._build_conversation(image_bytes=image_bytes)
109231

110232
try:
111-
response = self._send_conversation(conversation)
112-
raw_label = response["output"]["message"]["content"][0]["text"]
233+
response = self._send_conversation(message=message, system=self.system_content)
234+
raw_label = response.output.message.content[0].text
113235

114236
label = clean_label(raw_label)
115237
category = map_string_to_page_class(label)
@@ -138,52 +260,90 @@ def determine_class(
138260
)
139261
return PageClasses.UNKNOWN
140262

141-
def _build_conversation(self, image_bytes: bytes) -> list[dict]:
142-
content = [
143-
{"image": {"format": "jpeg", "source": {"bytes": self.examples_bytes[text.strip("@")]}}}
144-
if text.startswith("@") # @category encodes the image of the category and adds it to the content
145-
else {"text": text}
263+
def _build_conversation(self, image_bytes: bytes) -> PixtralMessageStack:
264+
"""Build the user message containing few-shot examples and the target image.
265+
266+
Args:
267+
image_bytes (bytes): Encoded bytes of the page to classify.
268+
269+
Returns:
270+
PixtralMessageStack: A user turn ready to send.
271+
"""
272+
# List of examples for pixtral model
273+
content_examples = [
274+
PixtralMessage(
275+
image=PixtralImage(
276+
format="jpeg",
277+
source=PixtralImageSource(bytes=self.examples_bytes[text.strip("@")]),
278+
)
279+
)
280+
if text.startswith("@")
281+
else PixtralMessage(text=text)
146282
for text in self.prompts_dict.get("examples_prompt", [])
147283
]
148-
content.append({"text": self.prompts_dict["user_prompt"]})
149-
content.append({"image": {"format": "jpeg", "source": {"bytes": image_bytes}}})
150284

151-
return [{"role": "user", "content": content}]
285+
# User prompt with content to classify
286+
content_user_text = PixtralMessage(text=self.prompts_dict["user_prompt"])
287+
content_user_img = PixtralMessage(
288+
image=PixtralImage(
289+
format="jpeg",
290+
source=PixtralImageSource(bytes=image_bytes),
291+
),
292+
)
152293

153-
def _send_conversation(self, conversation: list) -> dict:
154-
"""Sends the conversation to Bedrock with retry-on-throttle."""
155-
attempt = 0
156-
while True:
157-
self._rl.acquire() # ensure we dont exceed QPS
158-
try:
159-
return self.client.converse(
160-
modelId=self.model_id,
161-
messages=conversation,
162-
system=self.system_content,
163-
inferenceConfig={
164-
"maxTokens": self.config.get("max_tokens", 5),
165-
"temperature": self.config.get("temperature", 0.2),
166-
},
294+
return PixtralMessageStack(role="user", content=content_examples + [content_user_text, content_user_img])
295+
296+
297+
class PixtralFeatureExtraction(PixtralConnector):
298+
"""Uses the Pixtral vision model to extract features from PDF pages."""
299+
300+
def __init__(self, config: dict, aws_config: dict, system_prompt: str):
301+
"""Initialise the extractor with a custom system prompt.
302+
303+
Args:
304+
config (dict): Pixtral configuration dict.
305+
aws_config (dict): AWS settings dict.
306+
system_prompt (str): Instruction text sent as the system message for
307+
every extraction request.
308+
"""
309+
# Create connection to remote model
310+
PixtralConnector.__init__(self, config=config, aws_config=aws_config)
311+
self.system_prompt = PixtralMessage(text=system_prompt)
312+
313+
def _build_conversation(self, image_bytes: bytes) -> PixtralMessageStack:
314+
"""Build a minimal user message containing only the target page image.
315+
316+
Args:
317+
image_bytes (bytes): Encoded bytes of the page to process.
318+
319+
Returns:
320+
PixtralMessageStack: A 'user' turn with a single image content block.
321+
"""
322+
# Single content block: the target page image (no few-shot examples here)
323+
return PixtralMessageStack(
324+
role="user",
325+
content=[
326+
PixtralMessage(
327+
image=PixtralImage(
328+
format="jpeg",
329+
source=PixtralImageSource(bytes=image_bytes),
330+
)
167331
)
168-
except ClientError as e:
169-
# Retry on throttling
170-
if is_throttle_error(e) and attempt < self.max_retries:
171-
delay = min(self.backoff_cap, self.backoff_base * (2**attempt))
172-
# full jitter
173-
delay *= random.uniform(0.5, 1.5)
174-
logger.warning(f"Bedrock throttled (attempt {attempt + 1}/{self.max_retries}); sleep {delay:.2f}s")
175-
time.sleep(delay)
176-
attempt += 1
332+
],
333+
)
177334

178-
self._stats["retries"] += 1
179-
if "Throttl" in str(e):
180-
self._stats["throttles"] += 1
181-
continue
182-
raise # not retryable or out of retries
183-
except Exception:
184-
# Non-ClientError; retry a couple of times
185-
if attempt < 2:
186-
time.sleep(0.5 * (attempt + 1))
187-
attempt += 1
188-
continue
189-
raise
335+
def find(self, page: pymupdf.Page) -> str:
336+
"""Extract a feature from a single PDF page using the Pixtral model.
337+
338+
Args:
339+
page (pymupdf.Page): The PyMuPDF page object to process.
340+
341+
Returns:
342+
str: The raw text returned by the model (e.g. an extracted title).
343+
"""
344+
# Render the page to image bytes and wrap them in a user message
345+
image_bytes = get_page_image_bytes(page, max_mb=self.max_doc_size)
346+
content_user = self._build_conversation(image_bytes=image_bytes)
347+
348+
response = self._send_conversation(message=content_user, system=self.system_prompt)
349+
return response.output.message.content[0].text

src/page_graphics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ def get_images_from_page(page: pymupdf.Page) -> list[ImageRect]:
5151
return extracted_images
5252

5353

54-
def get_page_image_bytes(page: pymupdf.Page, page_number: int, max_mb: float = 4.5) -> bytes:
54+
def get_page_image_bytes(page: pymupdf.Page, max_mb: float = 4.5) -> bytes:
5555
"""Returns JPEG image bytes of a single PDF page. Downscales if image exceeds allowed size."""
5656
max_bytes = int(max_mb * 1024 * 1024)
5757
scale = 1.0
5858

5959
for attempt in range(10):
6060
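# Render and convert to JPEG. NOTE(review): from_page=0/to_page=0 copies the first page of page.parent, which is only `page` itself for single-page documents — confirm against callers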
6161
with pymupdf.open() as pdf_doc:
62-
pdf_doc.insert_pdf(page.parent, from_page=page_number, to_page=page_number)
62+
pdf_doc.insert_pdf(page.parent, from_page=0, to_page=0)
6363
page_bytes = pdf_doc.tobytes(deflate=True, garbage=3, use_objstms=1)
6464

6565
image_bytes = convert_pdf_to_jpeg(page_bytes, scale=scale)

0 commit comments

Comments
 (0)