Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions api/app/v2/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def process(
input_path,
)

# Process file in local directory
entities = script(
# Only entity grouping is needed for the API v2 response
_, entities = script(
input_path=temp_dir,
classifier_name="treebased",
model_path=DEFAULT_TREEBASED_MODEL_PATH,
Expand Down
Binary file added examples/reference_document.pdf
Binary file not shown.
15 changes: 8 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

# Load .env and check MLFlow
load_dotenv()
mlflow_tracking = os.getenv("MLFLOW_TRACKING").lower() == "true"
mlflow_tracking = os.getenv("MLFLOW_TRACKING") == "True"

if mlflow_tracking:
import mlflow
Expand Down Expand Up @@ -246,7 +246,7 @@ def main(
write_result: bool = False,
explain_model: bool = False,
return_entities: bool = False,
) -> list[ProcessorDocument] | list[ProcessorDocumentEntities]:
) -> list[ProcessorDocument] | tuple[list[ProcessorDocument], list[ProcessorDocumentEntities]]:
"""Run the page classification pipeline on input documents.

Args:
Expand All @@ -259,10 +259,11 @@ def main(
return_entities (bool): If True, return grouped entities instead of per-page results.

Returns:
list[ProcessorDocument] | list[ProcessorDocumentEntities]::
* A list of `ProcessorDocument` containing per-page classifications, or
* A list of `ProcessorDocumentEntities` containing grouped (multi-page) entities
when `return_entities=True`.
list[ProcessorDocument] | tuple[list[ProcessorDocument], list[ProcessorDocumentEntities]]:
* A list of `ProcessorDocument` containing per-page classifications (when `return_entities=False`), or
* A tuple of (list[ProcessorDocument], list[ProcessorDocumentEntities]) containing both
per-page results and grouped entities (when `return_entities=True`).


Raises:
ValueError: If an unsupported classifier is specified.
Expand Down Expand Up @@ -317,7 +318,7 @@ def main(
json.dumps([r.model_dump() for r in entities], indent=4),
encoding="utf-8",
)
return entities
return documents_pages, entities


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion src/classifiers/classifier_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

logger = logging.getLogger(__name__)
load_dotenv()
mlflow_tracking = os.getenv("MLFLOW_TRACKING").lower() == "true"
mlflow_tracking = os.getenv("MLFLOW_TRACKING") == "True"

if mlflow_tracking:
import mlflow
Expand Down
2 changes: 1 addition & 1 deletion src/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from src.page_classes import PageClasses

load_dotenv()
mlflow_tracking = os.getenv("MLFLOW_TRACKING").lower() == "true"
mlflow_tracking = os.getenv("MLFLOW_TRACKING") == "True"

if mlflow_tracking:
import mlflow
Expand Down
2 changes: 1 addition & 1 deletion src/models/treebased/model_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
xg_boost_config = read_params("config/xgboost_config.yml")

load_dotenv()
mlflow_tracking = os.getenv("MLFLOW_TRACKING").lower() == "true"
mlflow_tracking = os.getenv("MLFLOW_TRACKING") == "True"

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion src/models/treebased/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
logger = logging.getLogger(__name__)

load_dotenv()
mlflow_tracking = os.getenv("MLFLOW_TRACKING").lower() == "true"
mlflow_tracking = os.getenv("MLFLOW_TRACKING") == "True"

if mlflow_tracking:
import mlflow
Expand Down
2 changes: 1 addition & 1 deletion src/models/treebased/train_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
logger = logging.getLogger(__name__)

load_dotenv()
mlflow_tracking = os.getenv("MLFLOW_TRACKING").lower() == "true"
mlflow_tracking = os.getenv("MLFLOW_TRACKING") == "True"

if mlflow_tracking:
import mlflow
Expand Down
47 changes: 47 additions & 0 deletions src/scripts/create_reference_document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Create a reference document from all available classes."""

from pathlib import Path

import pymupdf

# One representative single-page PDF per document class; list order determines
# the page order in the generated reference document.
REFERENCE_PAGES = [
    "data/single_pages/boreprofile/742_7.pdf",
    "data/single_pages/diagram/250_3.pdf",
    "data/single_pages/geo_profile/24361_29.pdf",
    "data/single_pages/map/7066_11.pdf",
    "data/single_pages/section_header/1630_393.pdf",
    "data/single_pages/table/27898_16.pdf",
    "data/single_pages/text/1062_7.pdf",
    "data/single_pages/title_page/440_02_1.pdf",
    "data/single_pages/unknown/44179_195.pdf",
]


# Output path for reference document
OUTPUT_PDF = Path("examples") / "reference_document.pdf"


def main() -> None:
    """Build the reference PDF by concatenating one page per class.

    Reads each single-page PDF listed in ``REFERENCE_PAGES`` and appends its
    only page (index 0) to a new document saved at ``OUTPUT_PDF``.

    Raises:
        FileNotFoundError: If any path in ``REFERENCE_PAGES`` does not exist;
            the message lists exactly which files are missing.
    """
    # Verify all source files exist and report exactly which ones are missing
    # (the previous generic message gave no hint which path was wrong).
    missing = [path for path in REFERENCE_PAGES if not Path(path).exists()]
    if missing:
        raise FileNotFoundError(f"Missing reference pages: {', '.join(missing)}")

    # Create empty document to append pages into
    out_doc = pymupdf.Document()

    # Append the single page of each source PDF, in list order
    for source_path in REFERENCE_PAGES:
        src_doc = pymupdf.Document(source_path)
        # Each single-page PDF contains exactly one page (page index 0).
        out_doc.insert_pdf(src_doc, from_page=0, to_page=0)
        src_doc.close()

    # Ensure the output directory exists, then write and close the document
    OUTPUT_PDF.parent.mkdir(parents=True, exist_ok=True)
    out_doc.save(str(OUTPUT_PDF))
    out_doc.close()
    print(f"\nSaved: {OUTPUT_PDF}")


if __name__ == "__main__":
main()
59 changes: 59 additions & 0 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""End-to-end test for document classification and entity grouping."""

import pytest

from main import main as script
from src.constants import DEFAULT_TREEBASED_MODEL_PATH


@pytest.fixture
def reference_document() -> str:
    """Provide the example document that contains one page of every class.

    Returns:
        str: Path to file.
    """
    path = "examples/reference_document.pdf"
    return path


def test_end_to_end(reference_document: str) -> None:
    """Test main pipeline end to end.

    Args:
        reference_document (str): Reference document to classify.
    """
    # Run inference on the reference document, requesting grouped entities too
    documents_pages, documents_entities = script(
        input_path=reference_document,
        classifier_name="treebased",
        model_path=DEFAULT_TREEBASED_MODEL_PATH,
        write_result=False,
        return_entities=True,
    )
    # Exactly one document must come back in each output list
    assert documents_pages
    assert len(documents_pages) == 1
    assert documents_entities
    assert len(documents_entities) == 1

    # Unpack the single-document results
    document_pages = documents_pages[0]
    document_entities = documents_entities[0]

    # ---- Page-level output checks
    n_pages = document_pages.metadata.page_count
    page_numbers = [page.page for page in document_pages.pages]
    # Every page appears exactly once
    assert len(page_numbers) == n_pages
    assert len(set(page_numbers)) == n_pages

    # ---- Entity-level output checks: page ranges in bounds and ordered
    for entity in document_entities.entities:
        assert 0 < entity.page_start <= n_pages
        assert entity.page_start <= entity.page_end

    # ---- Coherence between page output and entity output
    # Test 1: Both representations agree on the total page count
    assert document_entities.page_count == n_pages
    # Test 2: Grouping preserves the set of detected classes
    page_classes = {page.classification for page in document_pages.pages}
    entity_classes = {entity.classification for entity in document_entities.entities}
    assert page_classes == entity_classes