Skip to content

Commit 12e221d

Browse files
committed
return predictions to user for further use
1 parent 01b700e commit 12e221d

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

roboflow/core/workspace.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,9 @@ def active_learning(
241241
conditionals: (dict) = dictionary of upload conditions
242242
use_localhost: (bool) = determines if local http format used or remote endpoint
243243
"""
244+
prediction_results = []
244245

245246
# ensure that all fields of conditionals have a key:value pair
246-
247247
conditionals["target_classes"] = (
248248
[]
249249
if "target_classes" not in conditionals
@@ -331,6 +331,13 @@ def active_learning(
331331
continue # skip this image if too similar or counter hits limit
332332

333333
predictions = inference_model.predict(image).json()["predictions"]
334+
# collect all predictions to return to user at end
335+
prediction_results.append(
336+
{
337+
"image":image,
338+
"predictions":predictions
339+
}
340+
)
334341

335342
# compare object and class count of predictions if enabled, continue if not enough occurances
336343
if not count_comparisons(
@@ -376,8 +383,9 @@ def active_learning(
376383
print(" >> image uploaded!")
377384
upload_project.upload(image, num_retry_uploads=3)
378385
break
379-
380-
return "complete"
386+
387+
# return predictions with filenames if globbed images from dir, otherwise return latest prediction result
388+
return prediction_results if type(raw_data_location) is not ndarray else prediction_results[-1]["predictions"]
381389

382390
def __str__(self):
383391
projects = self.projects()

0 commit comments

Comments
 (0)