Skip to content

Commit 09c5704

Browse files
author
marwan37
committed
Update loader to work with a DataFrame directly instead of a dict
1 parent 2d4b4bf commit 09c5704

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

omni-reader/steps/loaders.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def load_images(
6060
matching_files = glob.glob(full_pattern)
6161
if matching_files:
6262
all_images.extend(matching_files)
63-
logger.info(f"Found {len(matching_files)} images matching pattern {pattern}")
63+
logger.info(
64+
f"Found {len(matching_files)} images matching pattern {pattern}"
65+
)
6466

6567
# Validate image paths
6668
valid_images = []
@@ -72,7 +74,9 @@ def load_images(
7274

7375
# Log metadata about the loaded images
7476
image_names = [os.path.basename(path) for path in valid_images]
75-
image_extensions = [os.path.splitext(path)[1].lower() for path in valid_images]
77+
image_extensions = [
78+
os.path.splitext(path)[1].lower() for path in valid_images
79+
]
7680

7781
extension_counts = {}
7882
for ext in image_extensions:
@@ -98,16 +102,18 @@ def load_images(
98102

99103
@step(enable_cache=False)
100104
def load_ground_truth_texts(
101-
model_results: Dict[str, pl.DataFrame],
105+
model_results: pl.DataFrame,
102106
ground_truth_folder: Optional[str] = None,
103107
ground_truth_files: Optional[List[str]] = None,
104108
) -> Annotated[pl.DataFrame, "ground_truth"]:
105109
"""Load ground truth texts using image names found in model results."""
106110
if not ground_truth_folder and not ground_truth_files:
107-
raise ValueError("Either ground_truth_folder or ground_truth_files must be provided")
111+
raise ValueError(
112+
"Either ground_truth_folder or ground_truth_files must be provided"
113+
)
108114

109115
# Get the first model column to extract image names
110-
first_model_column = list(model_results.keys())[0]
116+
first_model_column = list(model_results.columns)[0]
111117

112118
image_names = model_results[first_model_column]["image_name"].to_list()
113119

@@ -182,11 +188,15 @@ def load_ocr_results(
182188
try:
183189
client = Client()
184190

185-
artifact = client.get_artifact_version(name_id_or_prefix=artifact_name, version=version)
191+
artifact = client.get_artifact_version(
192+
name_id_or_prefix=artifact_name, version=version
193+
)
186194

187195
ocr_results = load_artifact(artifact.id)
188196

189-
logger.info(f"Successfully loaded OCR results for {len(ocr_results)} models")
197+
logger.info(
198+
f"Successfully loaded OCR results for {len(ocr_results)} models"
199+
)
190200

191201
return ocr_results
192202
except Exception as e:

0 commit comments

Comments
 (0)