
Commit 035d78f

Merge pull request #1777 from roboflow/feature/add-doctr-prefedined-models
Add tests and fixes for DocTR model
2 parents a3bea9c + 6c45bc6 commit 035d78f

17 files changed: +477 −88 lines changed

inference_experimental/inference_exp/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
     MultiLabelClassificationPrediction,
 )
 from inference_exp.models.base.depth_estimation import DepthEstimationModel
-from inference_exp.models.base.documents_parsing import DocumentParsingModel
+from inference_exp.models.base.documents_parsing import StructuredOCRModel
 from inference_exp.models.base.embeddings import TextImageEmbeddingModel
 from inference_exp.models.base.instance_segmentation import (
     InstanceDetections,

inference_experimental/inference_exp/models/auto_loaders/core.py

Lines changed: 2 additions & 2 deletions
@@ -57,7 +57,7 @@
     MultiLabelClassificationModel,
 )
 from inference_exp.models.base.depth_estimation import DepthEstimationModel
-from inference_exp.models.base.documents_parsing import DocumentParsingModel
+from inference_exp.models.base.documents_parsing import StructuredOCRModel
 from inference_exp.models.base.embeddings import TextImageEmbeddingModel
 from inference_exp.models.base.instance_segmentation import InstanceSegmentationModel
 from inference_exp.models.base.keypoints_detection import KeyPointsDetectionModel
@@ -79,7 +79,7 @@
     ClassificationModel,
     MultiLabelClassificationModel,
     DepthEstimationModel,
-    DocumentParsingModel,
+    StructuredOCRModel,
     TextImageEmbeddingModel,
     InstanceSegmentationModel,
     KeyPointsDetectionModel,

inference_experimental/inference_exp/models/auto_loaders/models_registry.py

Lines changed: 6 additions & 2 deletions
@@ -15,6 +15,7 @@
 CLASSIFICATION_TASK = "classification"
 MULTI_LABEL_CLASSIFICATION_TASK = "multi-label-classification"
 DEPTH_ESTIMATION_TASK = "depth-estimation"
+STRUCTURED_OCR_TASK = "structured-ocr"


 @dataclass(frozen=True)
@@ -356,8 +357,11 @@ class RegistryEntry:
     ),
     ("depth-anything-v2", DEPTH_ESTIMATION_TASK, BackendType.HF): LazyClass(
         module_name="inference_exp.models.depth_anything_v2.depth_anything_v2_hf",
-        class_name="DepthAnythingV2HF"
-    )
+        class_name="DepthAnythingV2HF",
+    ),
+    ("doctr", STRUCTURED_OCR_TASK, BackendType.TORCH): LazyClass(
+        module_name="inference_exp.models.doctr.doctr_torch", class_name="DocTR"
+    ),
 }

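The registry maps (model architecture, task, backend) keys to LazyClass entries, so DocTR's dependencies are only imported when a "doctr"/"structured-ocr" package is actually loaded. A minimal sketch of how such a lazy wrapper can resolve its target (the resolve() helper and importlib usage here are assumptions, not the exact inference_exp implementation):

import importlib
from dataclasses import dataclass


@dataclass(frozen=True)
class LazyClass:
    """Defers importing a model class until the auto-loader selects it."""

    module_name: str
    class_name: str

    def resolve(self) -> type:
        # The heavy import (torch, doctr, ...) happens here, not at registry
        # definition time.
        module = importlib.import_module(self.module_name)
        return getattr(module, self.class_name)


# Mirrors the entry added in this commit for ("doctr", "structured-ocr", TORCH):
doctr_entry = LazyClass(
    module_name="inference_exp.models.doctr.doctr_torch", class_name="DocTR"
)
# DocTR = doctr_entry.resolve()  # imports inference_exp.models.doctr.doctr_torch on demand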
inference_experimental/inference_exp/models/base/documents_parsing.py

Lines changed: 2 additions & 4 deletions
@@ -11,15 +11,13 @@
 )


-class DocumentParsingModel(
+class StructuredOCRModel(
     ABC, Generic[PreprocessedInputs, PreprocessingMetadata, RawPrediction]
 ):

     @classmethod
     @abstractmethod
-    def from_pretrained(
-        cls, model_name_or_path: str, **kwargs
-    ) -> "DocumentParsingModel":
+    def from_pretrained(cls, model_name_or_path: str, **kwargs) -> "StructuredOCRModel":
         pass

     @property

inference_experimental/inference_exp/models/depth_anything_v2/depth_anything_v2_hf.py

Lines changed: 1 addition & 3 deletions
@@ -28,9 +28,7 @@ def from_pretrained(
             local_files_only=local_files_only,
         ).to(device)
         processor = AutoImageProcessor.from_pretrained(
-            model_name_or_path,
-            local_files_only=local_files_only,
-            use_fast=True
+            model_name_or_path, local_files_only=local_files_only, use_fast=True
         )
         return cls(model=model, processor=processor, device=device)

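For reference, the processor/model pair touched here follows the standard transformers depth-estimation flow. A hedged usage sketch (the checkpoint id and image handling are illustrative assumptions, not taken from this diff):

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForDepthEstimation

# Example checkpoint; the actual model package resolved by inference_exp may differ.
checkpoint = "depth-anything/Depth-Anything-V2-Small-hf"
processor = AutoImageProcessor.from_pretrained(checkpoint, use_fast=True)
model = AutoModelForDepthEstimation.from_pretrained(checkpoint)

image = Image.open("dog.jpeg")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
depth_map = outputs.predicted_depth  # (batch, height, width) relative depth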
inference_experimental/inference_exp/models/doctr/doctr_torch.py

Lines changed: 58 additions & 48 deletions
@@ -1,88 +1,98 @@
-import os
 from dataclasses import dataclass
 from typing import Callable, List, Optional, Tuple, Union

 import numpy as np
 import torch
 from doctr.io import Document
-from doctr.models import ocr_predictor
+from doctr.models import detection_predictor, ocr_predictor, recognition_predictor
 from inference_exp import Detections
 from inference_exp.configuration import DEFAULT_DEVICE
 from inference_exp.entities import ColorFormat, ImageDimensions
 from inference_exp.errors import CorruptedModelPackageError, ModelRuntimeError
-from inference_exp.models.base.documents_parsing import DocumentParsingModel
+from inference_exp.models.base.documents_parsing import StructuredOCRModel
 from inference_exp.models.common.model_packages import get_model_package_contents
 from inference_exp.utils.file_system import read_json

-WEIGHTS_NAMES_MAPPING = {
-    "db_resnet50": "db_resnet50-79bd7d70.pt",
-    "db_resnet34": "db_resnet34-cb6aed9e.pt",
-    "db_mobilenet_v3_large": "db_mobilenet_v3_large-21748dd0.pt",
-    "crnn_vgg16_bn": "crnn_vgg16_bn-9762b0b0.pt",
-    "crnn_mobilenet_v3_small": "crnn_mobilenet_v3_small_pt-3b919a02.pt",
-    "crnn_mobilenet_v3_large": "crnn_mobilenet_v3_large_pt-f5259ec2.pt",
+SUPPORTED_DETECTION_MODELS = {
+    "fast_base",
+    "fast_small",
+    "fast_tiny",
+    "db_resnet50",
+    "db_resnet34",
+    "db_mobilenet_v3_large",
+    "linknet_resnet18",
+    "linknet_resnet34",
+    "linknet_resnet50",
+}
+SUPPORTED_RECOGNITION_MODELS = {
+    "crnn_vgg16_bn",
+    "crnn_mobilenet_v3_small",
+    "crnn_mobilenet_v3_large",
+    "master",
+    "sar_resnet31",
+    "vitstr_small",
+    "vitstr_base",
+    "parseq",
 }


-class DocTR(DocumentParsingModel[List[np.ndarray], ImageDimensions, Document]):
+class DocTR(StructuredOCRModel[List[np.ndarray], ImageDimensions, Document]):

     @classmethod
     def from_pretrained(
         cls,
         model_name_or_path: str,
         device: torch.device = DEFAULT_DEVICE,
+        assume_straight_pages: bool = True,
+        preserve_aspect_ratio: bool = True,
+        detection_max_batch_size: int = 2,
+        recognition_max_batch_size: int = 128,
         **kwargs,
-    ) -> "DocumentParsingModel":
-        os.environ["DOCTR_CACHE_DIR"] = model_name_or_path
+    ) -> "StructuredOCRModel":
         model_package_content = get_model_package_contents(
             model_package_dir=model_name_or_path,
-            elements=["doctr_det", "doctr_rec", "config.json"],
+            elements=["detection_weights.pt", "recognition_weights.pt", "config.json"],
         )
         config = parse_model_config(config_path=model_package_content["config.json"])
-        os.makedirs(f"{model_name_or_path}/doctr_det/models/", exist_ok=True)
-        os.makedirs(f"{model_name_or_path}/doctr_rec/models/", exist_ok=True)
-        det_model_source_path = os.path.join(
-            model_name_or_path, "doctr_det", config.det_model, "model.pt"
-        )
-        rec_model_source_path = os.path.join(
-            model_name_or_path, "doctr_rec", config.rec_model, "model.pt"
-        )
-        if not os.path.exists(det_model_source_path):
-            raise CorruptedModelPackageError(
-                message="Could not initialize DocTR model - could not find detection model weights.",
-                help_url="https://todo",
-            )
-        if not os.path.exists(rec_model_source_path):
-            raise CorruptedModelPackageError(
-                message="Could not initialize DocTR model - could not find recognition model weights.",
-                help_url="https://todo",
-            )
-        if config.det_model not in WEIGHTS_NAMES_MAPPING:
+        if config.det_model not in SUPPORTED_DETECTION_MODELS:
             raise CorruptedModelPackageError(
                 message=f"{config.det_model} model denoted in configuration not supported as DocTR detection model.",
                 help_url="https://todo",
             )
-        if config.rec_model not in WEIGHTS_NAMES_MAPPING:
+        if config.rec_model not in SUPPORTED_RECOGNITION_MODELS:
             raise CorruptedModelPackageError(
-                message=f"{config.det_model} model denoted in configuration not supported as DocTR recognition model.",
+                message=f"{config.rec_model} model denoted in configuration not supported as DocTR recognition model.",
                 help_url="https://todo",
             )
-        det_model_target_path = os.path.join(
-            model_name_or_path, "models", WEIGHTS_NAMES_MAPPING[config.det_model]
+        det_model = detection_predictor(
+            arch=config.det_model,
+            pretrained=False,
+            assume_straight_pages=assume_straight_pages,
+            preserve_aspect_ratio=preserve_aspect_ratio,
+            batch_size=detection_max_batch_size,
+        )
+        det_model.model.to(device)
+        detector_weights = torch.load(
+            model_package_content["detection_weights.pt"],
+            weights_only=True,
+            map_location=device,
+        )
+        det_model.model.load_state_dict(detector_weights)
+        rec_model = recognition_predictor(
+            arch=config.rec_model,
+            pretrained=False,
+            batch_size=recognition_max_batch_size,
         )
-        rec_model_target_path = os.path.join(
-            model_name_or_path, "models", WEIGHTS_NAMES_MAPPING[config.rec_model]
+        rec_model.model.to(device)
+        rec_weights = torch.load(
+            model_package_content["recognition_weights.pt"],
+            weights_only=True,
+            map_location=device,
         )
-        if os.path.exists(det_model_target_path):
-            os.remove(det_model_target_path)
-        os.symlink(det_model_source_path, det_model_target_path)
-        if os.path.exists(rec_model_target_path):
-            os.remove(rec_model_target_path)
-        os.symlink(rec_model_source_path, rec_model_target_path)
+        rec_model.model.load_state_dict(rec_weights)
         model = ocr_predictor(
-            det_arch=config.det_model,
-            reco_arch=config.rec_model,
-            pretrained=True,
+            det_arch=det_model.model,
+            reco_arch=rec_model.model,
         ).to(device=device)
         return cls(model=model, device=device)

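With this change the model package ships plain state dicts (detection_weights.pt, recognition_weights.pt) that are loaded into freshly built docTR predictors, instead of symlinking files into docTR's cache directory. A standalone sketch of the same loading pattern (architecture names, file paths, and the dummy image are illustrative assumptions; the docTR calls mirror the diff above):

import numpy as np
import torch
from doctr.models import detection_predictor, ocr_predictor, recognition_predictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build un-pretrained predictors for the architectures named in config.json.
det = detection_predictor(arch="db_resnet50", pretrained=False, batch_size=2)
rec = recognition_predictor(arch="crnn_vgg16_bn", pretrained=False, batch_size=128)

# Load the fine-tuned weights shipped in the model package.
det.model.load_state_dict(
    torch.load("detection_weights.pt", weights_only=True, map_location=device)
)
rec.model.load_state_dict(
    torch.load("recognition_weights.pt", weights_only=True, map_location=device)
)

# Compose an end-to-end OCR predictor from the two loaded modules.
ocr = ocr_predictor(det_arch=det.model, reco_arch=rec.model).to(device)

# docTR predictors take a list of HWC uint8 numpy images and return a Document.
page = np.zeros((256, 256, 3), dtype=np.uint8)
document = ocr([page])
print(document.render())  # plain-text rendering of the parsed page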
inference_experimental/inference_exp/models/moondream2/moondream2_hf.py

Lines changed: 2 additions & 2 deletions
@@ -37,8 +37,8 @@ def from_pretrained(
         if torch.mps.is_available():
             raise ModelRuntimeError(
                 message=f"This model cannot run on Apple device with MPS unit - original implementation contains bug "
-                    f"preventing proper allocation of tensors which causes runtime error. Run this model on the "
-                    f"machine with Nvidia GPU or x86 CPU.",
+                f"preventing proper allocation of tensors which causes runtime error. Run this model on the "
+                f"machine with Nvidia GPU or x86 CPU.",
                 help_url="https://todo",
             )
         model_package_content = get_model_package_contents(

inference_experimental/inference_exp/models/rfdetr/rfdetr_instance_segmentation_pytorch.py

Lines changed: 3 additions & 1 deletion
@@ -124,7 +124,9 @@ def from_pretrained(
         model_config = CONFIG_FOR_MODEL_TYPE[model_type](device=device)
         checkpoint_num_classes = weights_dict["class_embed.bias"].shape[0]
         model_config.num_classes = checkpoint_num_classes - 1
-        model_config.resolution = inference_config.network_input.training_input_size.height
+        model_config.resolution = (
+            inference_config.network_input.training_input_size.height
+        )
         model = build_model(config=model_config)
         model.load_state_dict(weights_dict)
         model = model.eval().to(device)

inference_experimental/inference_exp/models/rfdetr/rfdetr_object_detection_pytorch.py

Lines changed: 3 additions & 1 deletion
@@ -130,7 +130,9 @@ def from_pretrained(
         model_config = CONFIG_FOR_MODEL_TYPE[model_type](device=device)
         checkpoint_num_classes = weights_dict["class_embed.bias"].shape[0]
         model_config.num_classes = checkpoint_num_classes - 1
-        model_config.resolution = inference_config.network_input.training_input_size.height
+        model_config.resolution = (
+            inference_config.network_input.training_input_size.height
+        )
         model = build_model(config=model_config)
         model.load_state_dict(weights_dict)
         model = model.eval().to(device)

inference_experimental/tests/e2e_platform_tests/conftest.py

Lines changed: 11 additions & 3 deletions
@@ -6,13 +6,13 @@
 import requests
 from filelock import FileLock

-ASSETS_DIR = os.path.abspath(
-    os.path.join(os.path.dirname(__file__), "assets")
-)
+ASSETS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "assets"))
 DOG_IMAGE_PATH = os.path.join(ASSETS_DIR, "images", "dog.jpeg")
 DOG_IMAGE_URL = (
     "https://storage.googleapis.com/roboflow-tests-assets/test-images/dog.jpeg"
 )
+OCR_TEST_IMAGE_PATH = os.path.join(ASSETS_DIR, "ocr_test_image.png")
+OCR_TEST_IMAGE_URL = "https://storage.googleapis.com/roboflow-tests-assets/test-images/ocr_test_image.png"


 @pytest.fixture()
@@ -28,6 +28,14 @@ def dog_image_numpy() -> np.ndarray:
     return image


+@pytest.fixture(scope="function")
+def ocr_test_image_numpy() -> np.ndarray:
+    _download_if_not_exists(file_path=OCR_TEST_IMAGE_PATH, url=OCR_TEST_IMAGE_URL)
+    image = cv2.imread(OCR_TEST_IMAGE_PATH)
+    assert image is not None, "Could not load OCR test image"
+    return image
+
+
 def _download_if_not_exists(file_path: str, url: str, lock_timeout: int = 180) -> None:
     os.makedirs(os.path.dirname(file_path), exist_ok=True)
     lock_path = f"{file_path}.lock"

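The new ocr_test_image_numpy fixture relies on _download_if_not_exists, whose body is cut off in this hunk. A plausible sketch of such a helper, assuming the FileLock guard and requests download implied by the imports above (not the verbatim implementation):

import os

import requests
from filelock import FileLock


def _download_if_not_exists(file_path: str, url: str, lock_timeout: int = 180) -> None:
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    lock_path = f"{file_path}.lock"
    # The file lock keeps parallel pytest workers from downloading the same
    # asset twice or reading a partially written file.
    with FileLock(lock_path, timeout=lock_timeout):
        if os.path.exists(file_path):
            return
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        with open(file_path, "wb") as f:
            f.write(response.content)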