88from dotenv import load_dotenv
99from swissgeol_doc_processing .utils .file_utils import read_params as swissgeol_read_params
1010
11+ from src .boreprofile .entity_parser import document_to_boreprofiles
1112from src .classifiers .classifier_factory import ClassifierTypes , create_classifier
1213from src .constants import DEFAULT_TREEBASED_MODEL_PATH
14+ from src .page_classes import PageClasses
1315from src .page_structure import (
1416 ProcessedEntities ,
1517 ProcessorDocument ,
@@ -92,7 +94,7 @@ def forward_document(
9294 """Infer document classes.
9395
9496 Args:
95- pdf_files (list[Path]): List fo documents to classify.
97+ pdf_files (list[Path]): List of documents to classify.
9698 matching_params (dict): Dict of parameters for document processing.
9799 borehole_matching_params (dict): Dict of parameters for borehole matching.
98100 model_path (str, optional): Path to pretrained model.
@@ -114,13 +116,45 @@ def forward_document(
114116 return processor .process_batch (pdf_files )
115117
116118
def forward_document_entities_group(
    classification: PageClasses,
    page_start: int,
    page_end: int,
    language: str | None,
    pdf_file: Path,
) -> list[ProcessedEntities]:
    """Extract entities from a group of consecutive pages with the same classification.

    Boreprofile pages are delegated to the dedicated boreprofile parser; every
    other classification is wrapped into a single grouped entity record.

    Args:
        classification (PageClasses): The classification type of the page group.
        page_start (int): First page index in the consecutive group (1-based).
        page_end (int): Last page index in the consecutive group (1-based).
        language (str | None): Detected language of the page group, or None when unknown.
        pdf_file (Path): Path to the source PDF file.

    Returns:
        list[ProcessedEntities]: Extracted entities from the page group.
    """
    # Boreprofile pages require entity extraction via the dedicated parser.
    if classification == PageClasses.BOREPROFILE:
        return document_to_boreprofiles(pdf_file=pdf_file, page_start=page_start, page_end=page_end, lang=language)
    # Any other page class maps 1:1 to a single grouped entity.
    return [
        ProcessedEntities(
            classification=classification,
            page_start=page_start,
            page_end=page_end,
            language=language,
        )
    ]
149+
150+
117151def forward_document_entities (
118152 documents : list [ProcessorDocument ],
119153) -> list [ProcessorDocumentEntities ]:
120154 """Convert classified documents pages to entities.
121155
122156 Args:
123- documents (list[ProcessorDocument]): List of documents to process .
157+ documents (list[ProcessorDocument]): List of documents to convert to entities .
124158
125159 Returns:
126160 list[ProcessorDocumentEntities]: Processed documents entities
@@ -132,19 +166,20 @@ def forward_document_entities(
132166 # Iterate over grouped entities types
133167 for (pages_type , lang ), pages in document .group_pages_by_type ():
134168 # Get pages sequences
135- pages_id = sorted ([page .page for page in pages ])
136- results_entities .extend (
137- [
138- ProcessedEntities (
139- classification = pages_type ,
140- page_start = min (pages_group ),
141- page_end = max (pages_group ),
142- language = lang ,
143- )
144- # Group consecutive [1,2,10] -> [1,2], [10]
145- for pages_group in group_consecutive (pages_id )
146- ]
147- )
169+ page_numbers = sorted ([page .page for page in pages ])
170+ entities = [
171+ entity
172+ for pages_group in group_consecutive (page_numbers ) # Group consecutive [1,2,10] -> [1,2], [10]
173+ for entity in forward_document_entities_group (
174+ classification = pages_type ,
175+ page_start = min (pages_group ),
176+ page_end = max (pages_group ),
177+ language = lang ,
178+ pdf_file = document .path ,
179+ )
180+ ]
181+ # Extend entry
182+ results_entities .extend (entities )
148183 # Create document from filename, metadata, entities
149184 documents_entities .append (
150185 ProcessorDocumentEntities (
@@ -166,20 +201,20 @@ def main(
166201 write_result : bool = False ,
167202 explain_model : bool = False ,
168203 return_entities : bool = False ,
169- ) -> tuple [ list [ProcessorDocument ] | list [ProcessorDocumentEntities ] ]:
204+ ) -> list [ProcessorDocument ] | list [ProcessorDocumentEntities ]:
170205 """Run the page classification pipeline on input documents.
171206
172207 Args:
173208 input_path (str): Path to directory with PDF pages or documents.
174209 ground_truth_path (str, optional): Path to ground truth JSON file for evaluation.
175210 model_path (str, optional): Path to pretrained model.
176211 classifier_name (str, optional): Classifier to use ("treebased", "pixtral", etc.).
177- write_result (bool): If True, writes results to prediction.json.
212+ write_result (bool): If True, and return_entities is True, writes results to prediction.json.
178213 explain_model (bool): If True, generates plots to explain the model's choices.
179214 return_entities (bool): If True, return grouped entities instead of per-page results.
180215
181- Return :
182- tuple[ list[ProcessorDocument] | list[ProcessorDocumentEntities]] :
216+ Returns :
217+ list[ProcessorDocument] | list[ProcessorDocumentEntities]:
183218 * A list of `ProcessorDocument` containing per-page classifications, or
184219 * A list of `ProcessorDocumentEntities` containing grouped (multi-page) entities
185220 when `return_entities=True`.
@@ -200,7 +235,7 @@ def main(
200235 pdf_files = get_pdf_files (input_path )
201236 if not pdf_files :
202237 logger .error ("No valid PDFs found." )
203- return [], []
238+ return []
204239
205240 # Run individual page classification
206241 documents_pages = forward_document (
@@ -212,15 +247,6 @@ def main(
212247 explain_model = explain_model ,
213248 )
214249
215- # Check if data need to be saved
216- if write_result :
217- output_file = Path ("data" ) / "prediction.json"
218- output_file .parent .mkdir (parents = True , exist_ok = True )
219- output_file .write_text (
220- json .dumps ([r .model_dump () for r in documents_pages ], indent = 4 ),
221- encoding = "utf-8" ,
222- )
223-
224250 # Check if GT need to be computed
225251 if ground_truth_path :
226252 from src .evaluation import evaluate_results
@@ -236,7 +262,17 @@ def main(
236262 if not return_entities :
237263 return documents_pages
238264 else :
239- return forward_document_entities (documents = documents_pages )
265+ entities = forward_document_entities (documents = documents_pages )
266+
267+ # Check if data needs to be saved
268+ if write_result :
269+ output_file = Path ("data" ) / "prediction.json"
270+ output_file .parent .mkdir (parents = True , exist_ok = True )
271+ output_file .write_text (
272+ json .dumps ([r .model_dump () for r in entities ], indent = 4 ),
273+ encoding = "utf-8" ,
274+ )
275+ return entities
240276
241277
242278if __name__ == "__main__" :
@@ -301,4 +337,5 @@ def main(
301337 classifier_name = args .classifier ,
302338 write_result = args .write_results ,
303339 explain_model = args .explain_model ,
340+ return_entities = True ,
304341 )
0 commit comments