Skip to content

Commit 2c636e3

Browse files
Merge pull request #93 from swisstopo/feat/issue-87/return-document-title
Add document title extraction
2 parents 419f0c2 + 0c1fbf5 commit 2c636e3

File tree

8 files changed

+260
-52
lines changed

8 files changed

+260
-52
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,10 @@ The output of the pipeline is dependent of the version queried.
9898
"entities": [ // List of elements present in file
9999
{
100100
"classification": "boreprofile", // Type of element (PageClasses)
101+
"language": "de", // Detected language
101102
"page_start": 1, // Starting page
102103
"page_end": 3, // Ending page
103-
"language": "de" // Detected language
104+
"title": "BS1" // Entity title (None if not found)
104105
}
105106
]
106107
}

api/app/v2/schemas.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ class CollectResponse(BaseModel):
2020
"entities": [
2121
{
2222
"classification": "boreprofile",
23+
"language": "de",
2324
"page_start": 1,
2425
"page_end": 3,
25-
"language": "de",
26+
"title": "BS1",
2627
},
2728
],
2829
},

config/local_matching_params.yml

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
table_of_contents:
32
en:
43
- table of contents
@@ -33,20 +32,20 @@ caption_description:
3332
boreprofile:
3433
en:
3534
must_contain:
36-
- [ profile, log, lithostratigraphy ]
37-
- [ borehole, drilling, bore ]
35+
- [profile, log, lithostratigraphy]
36+
- [borehole, drilling, bore]
3837
de:
3938
must_contain:
40-
- [ profil, log, lithostratigraphie]
41-
- [ bohrung, bohrloch, bohr ]
39+
- [profil, log, lithostratigraphie]
40+
- [bohrung, bohrloch, bohr]
4241
fr:
4342
must_contain:
44-
- [ profil, log, lithostratigraphique ]
45-
- [ forage, sondage ]
43+
- [profil, log, lithostratigraphique]
44+
- [forage, sondage]
4645
it:
4746
must_contain:
48-
- [ profilo, log, stratigrafico, diagramma ]
49-
- [ perforazione, sondaggio ]
47+
- [profilo, log, stratigrafico, diagramma]
48+
- [perforazione, sondaggio]
5049

5150
boreprofile:
5251
en:
@@ -173,7 +172,6 @@ geo_profile:
173172
- sezione verticale
174173
- taglio geologico
175174

176-
177175
diagram:
178176
en:
179177
- diagram
@@ -245,4 +243,4 @@ open_ended_depth_key:
245243
- à partir de
246244
- from
247245
- starting at
248-
- a partire
246+
- a partire

main.py

Lines changed: 67 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
from dotenv import load_dotenv
99
from swissgeol_doc_processing.utils.file_utils import read_params as swissgeol_read_params
1010

11+
from src.boreprofile.entity_parser import document_to_boreprofiles
1112
from src.classifiers.classifier_factory import ClassifierTypes, create_classifier
1213
from src.constants import DEFAULT_TREEBASED_MODEL_PATH
14+
from src.page_classes import PageClasses
1315
from src.page_structure import (
1416
ProcessedEntities,
1517
ProcessorDocument,
@@ -92,7 +94,7 @@ def forward_document(
9294
"""Infer document classes.
9395
9496
Args:
95-
pdf_files (list[Path]): List fo documents to classify.
97+
pdf_files (list[Path]): List of documents to classify.
9698
matching_params (dict): Dict of parameters for document processing.
9799
borehole_matching_params (dict): Dict of parameters for borehole matching.
98100
model_path (str, optional): Path to pretrained model.
@@ -114,13 +116,45 @@ def forward_document(
114116
return processor.process_batch(pdf_files)
115117

116118

119+
def forward_document_entities_group(
120+
classification: PageClasses,
121+
page_start: int,
122+
page_end: int,
123+
language: str | None,
124+
pdf_file: Path,
125+
) -> list[ProcessedEntities]:
126+
"""Extract entities from a group of consecutive pages with the same classification.
127+
128+
Args:
129+
classification (PageClasses): The classification type of the page group.
130+
page_start (int): First page index in the consecutive group (1-based).
131+
page_end (int): Last page index in the consecutive group (1-based).
132+
language (str): Detected language of the page group.
133+
pdf_file (Path): Path to the source PDF file.
134+
135+
Returns:
136+
list[ProcessedEntities]: Extracted entities from the page group.
137+
"""
138+
if classification == PageClasses.BOREPROFILE:
139+
return document_to_boreprofiles(pdf_file=pdf_file, page_start=page_start, page_end=page_end, lang=language)
140+
else:
141+
return [
142+
ProcessedEntities(
143+
classification=classification,
144+
page_start=page_start,
145+
page_end=page_end,
146+
language=language,
147+
)
148+
]
149+
150+
117151
def forward_document_entities(
118152
documents: list[ProcessorDocument],
119153
) -> list[ProcessorDocumentEntities]:
120154
"""Convert classified documents pages to entities.
121155
122156
Args:
123-
documents (list[ProcessorDocument]): List of documents to process.
157+
documents (list[ProcessorDocument]): List of documents to convert to entities.
124158
125159
Returns:
126160
list[ProcessorDocumentEntities]: Processed documents entities
@@ -132,19 +166,20 @@ def forward_document_entities(
132166
# Iterate over grouped entities types
133167
for (pages_type, lang), pages in document.group_pages_by_type():
134168
# Get pages sequences
135-
pages_id = sorted([page.page for page in pages])
136-
results_entities.extend(
137-
[
138-
ProcessedEntities(
139-
classification=pages_type,
140-
page_start=min(pages_group),
141-
page_end=max(pages_group),
142-
language=lang,
143-
)
144-
# Group consecutive [1,2,10] -> [1,2], [10]
145-
for pages_group in group_consecutive(pages_id)
146-
]
147-
)
169+
page_numbers = sorted([page.page for page in pages])
170+
entities = [
171+
entity
172+
for pages_group in group_consecutive(page_numbers) # Group consecutive [1,2,10] -> [1,2], [10]
173+
for entity in forward_document_entities_group(
174+
classification=pages_type,
175+
page_start=min(pages_group),
176+
page_end=max(pages_group),
177+
language=lang,
178+
pdf_file=document.path,
179+
)
180+
]
181+
# Extend entry
182+
results_entities.extend(entities)
148183
# Create document from filename, metadata, entities
149184
documents_entities.append(
150185
ProcessorDocumentEntities(
@@ -166,20 +201,20 @@ def main(
166201
write_result: bool = False,
167202
explain_model: bool = False,
168203
return_entities: bool = False,
169-
) -> tuple[list[ProcessorDocument] | list[ProcessorDocumentEntities]]:
204+
) -> list[ProcessorDocument] | list[ProcessorDocumentEntities]:
170205
"""Run the page classification pipeline on input documents.
171206
172207
Args:
173208
input_path (str): Path to directory with PDF pages or documents.
174209
ground_truth_path (str, optional): Path to ground truth JSON file for evaluation.
175210
model_path (str, optional): Path to pretrained model.
176211
classifier_name (str, optional): Classifier to use ("treebased", "pixtral", etc.).
177-
write_result (bool): If True, writes results to prediction.json.
212+
write_result (bool): If True, and return_entities is True, writes results to prediction.json.
178213
explain_model (bool): If True, generates plots to explain the model's choices.
179214
return_entities (bool): If True, return grouped entities instead of per-page results.
180215
181-
Return:
182-
tuple[list[ProcessorDocument] | list[ProcessorDocumentEntities]]:
216+
Returns:
217+
list[ProcessorDocument] | list[ProcessorDocumentEntities]:
183218
* A list of `ProcessorDocument` containing per-page classifications, or
184219
* A list of `ProcessorDocumentEntities` containing grouped (multi-page) entities
185220
when `return_entities=True`.
@@ -200,7 +235,7 @@ def main(
200235
pdf_files = get_pdf_files(input_path)
201236
if not pdf_files:
202237
logger.error("No valid PDFs found.")
203-
return [], []
238+
return []
204239

205240
# Run individual page classification
206241
documents_pages = forward_document(
@@ -212,15 +247,6 @@ def main(
212247
explain_model=explain_model,
213248
)
214249

215-
# Check if data need to be saved
216-
if write_result:
217-
output_file = Path("data") / "prediction.json"
218-
output_file.parent.mkdir(parents=True, exist_ok=True)
219-
output_file.write_text(
220-
json.dumps([r.model_dump() for r in documents_pages], indent=4),
221-
encoding="utf-8",
222-
)
223-
224250
# Check if GT need to be computed
225251
if ground_truth_path:
226252
from src.evaluation import evaluate_results
@@ -236,7 +262,17 @@ def main(
236262
if not return_entities:
237263
return documents_pages
238264
else:
239-
return forward_document_entities(documents=documents_pages)
265+
entities = forward_document_entities(documents=documents_pages)
266+
267+
# Check if data needs to be saved
268+
if write_result:
269+
output_file = Path("data") / "prediction.json"
270+
output_file.parent.mkdir(parents=True, exist_ok=True)
271+
output_file.write_text(
272+
json.dumps([r.model_dump() for r in entities], indent=4),
273+
encoding="utf-8",
274+
)
275+
return entities
240276

241277

242278
if __name__ == "__main__":
@@ -301,4 +337,5 @@ def main(
301337
classifier_name=args.classifier,
302338
write_result=args.write_results,
303339
explain_model=args.explain_model,
340+
return_entities=True,
304341
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ requires-python = ">=3.11,<3.14"
1111

1212
# Production-only dependencies for API runtime
1313
dependencies = [
14-
"swissgeol-boreholes-dataextraction @ https://github.com/swisstopo/swissgeol-boreholes-dataextraction/releases/download/v1.0.138/swissgeol_boreholes_dataextraction-1.0.138-py3-none-any.whl",
14+
"swissgeol-boreholes-dataextraction @ https://github.com/swisstopo/swissgeol-boreholes-dataextraction/releases/download/v1.0.143/swissgeol_boreholes_dataextraction-1.0.143-py3-none-any.whl",
1515
"boto3==1.40.12",
1616
"fastapi==0.116.1",
1717
"boto3-stubs[s3]>=1.40.12,<2.0.0",

0 commit comments

Comments (0)