Skip to content

Commit 86cd5bf

Browse files
authored
Merge pull request #312 from lrosemberg/rosemberg/refactoring-dataset-upload
Refactoring dataset upload
2 parents 6f1b08a + f391451 commit 86cd5bf

File tree

8 files changed

+363
-127
lines changed

8 files changed

+363
-127
lines changed

roboflow/adapters/rfapi.py

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,19 @@ class RoboflowError(Exception):
1414
pass
1515

1616

17-
class UploadError(RoboflowError):
18-
pass
17+
class ImageUploadError(RoboflowError):
18+
def __init__(self, message, status_code=None):
19+
self.message = message
20+
self.status_code = status_code
21+
self.retries = 0
22+
super().__init__(self.message)
23+
24+
25+
class AnnotationSaveError(RoboflowError):
26+
def __init__(self, message, status_code=None):
27+
self.message = message
28+
self.status_code = status_code
29+
super().__init__(self.message)
1930

2031

2132
def get_workspace(api_key, workspace_url):
@@ -78,24 +89,38 @@ def upload_image(
7889

7990
else:
8091
# Hosted image upload url
81-
8292
upload_url = _hosted_upload_url(api_key, project_url, image_path, split, coalesced_batch_name, tag_names)
93+
8394
# Get response
8495
response = requests.post(upload_url, timeout=(300, 300))
96+
8597
responsejson = None
8698
try:
8799
responsejson = response.json()
88100
except Exception:
89101
pass
102+
90103
if response.status_code != 200:
91104
if responsejson:
92-
raise UploadError(f"Bad response: {response.status_code}: {responsejson}")
105+
err_msg = responsejson
106+
107+
if err_msg.get("error"):
108+
err_msg = err_msg["error"]
109+
110+
if err_msg.get("message"):
111+
err_msg = err_msg["message"]
112+
113+
raise ImageUploadError(err_msg, status_code=response.status_code)
93114
else:
94-
raise UploadError(f"Bad response: {response}")
115+
raise ImageUploadError(str(response), status_code=response.status_code)
116+
95117
if not responsejson: # fail fast
96-
raise UploadError(f"upload image {image_path} 200 OK, unexpected response: {response}")
118+
raise ImageUploadError(str(response), status_code=response.status_code)
119+
97120
if not (responsejson.get("success") or responsejson.get("duplicate")):
98-
raise UploadError(f"Server rejected image: {responsejson}")
121+
message = responsejson.get("message") or str(responsejson)
122+
raise ImageUploadError(message)
123+
99124
return responsejson
100125

101126

@@ -128,24 +153,28 @@ def save_annotation(
128153
headers={"Content-Type": "application/json"},
129154
timeout=(60, 60),
130155
)
156+
157+
# Handle response
131158
responsejson = None
132159
try:
133160
responsejson = response.json()
134161
except Exception:
135162
pass
163+
136164
if not responsejson:
137-
raise _save_annotation_error(image_id, response)
165+
raise _save_annotation_error(response)
138166
if response.status_code not in (200, 409):
139-
raise _save_annotation_error(image_id, response)
167+
raise _save_annotation_error(response)
140168
if response.status_code == 409:
141169
if "already annotated" in responsejson.get("error", {}).get("message"):
142170
return {"warn": "already annotated"}
143171
else:
144-
raise _save_annotation_error(image_id, response)
172+
raise _save_annotation_error(response)
145173
if responsejson.get("error"):
146-
raise _save_annotation_error(image_id, response)
174+
raise _save_annotation_error(response)
147175
if not responsejson.get("success"):
148-
raise _save_annotation_error(image_id, response)
176+
raise _save_annotation_error(response)
177+
149178
return responsejson
150179

151180

@@ -191,17 +220,20 @@ def _local_upload_url(api_key, project_url, batch_name, tag_names, sequence_numb
191220
return _upload_url(api_key, project_url, **query_params)
192221

193222

194-
def _save_annotation_error(image_id, response):
195-
errmsg = f"save annotation for {image_id} / "
223+
def _save_annotation_error(response):
196224
responsejson = None
197225
try:
198226
responsejson = response.json()
199227
except Exception:
200228
pass
229+
201230
if not responsejson:
202-
errmsg += f"bad response: {response.status_code}: {response}"
203-
elif responsejson.get("error"):
204-
errmsg += f"bad response: {response.status_code}: {responsejson['error']}"
205-
else:
206-
errmsg += f"bad response: {response.status_code}: {responsejson}"
207-
return UploadError(errmsg)
231+
return AnnotationSaveError(response, status_code=response.status_code)
232+
233+
if responsejson.get("error"):
234+
err_msg = responsejson["error"]
235+
if err_msg.get("message"):
236+
err_msg = err_msg["message"]
237+
return AnnotationSaveError(err_msg, status_code=response.status_code)
238+
239+
return AnnotationSaveError(str(responsejson), status_code=response.status_code)

roboflow/core/project.py

Lines changed: 96 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import json
33
import mimetypes
44
import os
5-
import re
65
import sys
76
import time
87
import warnings
@@ -12,6 +11,7 @@
1211
import requests
1312

1413
from roboflow.adapters import rfapi
14+
from roboflow.adapters.rfapi import ImageUploadError
1515
from roboflow.config import API_URL, DEMO_KEYS
1616
from roboflow.core.version import Version
1717
from roboflow.util.general import Retry
@@ -465,6 +465,76 @@ def upload(
465465
print("[ " + path + " ] was skipped.")
466466
continue
467467

468+
def upload_image(
469+
self,
470+
image_path=None,
471+
hosted_image=False,
472+
split="train",
473+
num_retry_uploads=0,
474+
batch_name=None,
475+
tag_names=[],
476+
sequence_number=None,
477+
sequence_size=None,
478+
**kwargs,
479+
):
480+
project_url = self.id.rsplit("/")[1]
481+
482+
t0 = time.time()
483+
upload_retry_attempts = 0
484+
retry = Retry(num_retry_uploads, ImageUploadError)
485+
486+
try:
487+
image = retry(
488+
rfapi.upload_image,
489+
self.__api_key,
490+
project_url,
491+
image_path,
492+
hosted_image=hosted_image,
493+
split=split,
494+
batch_name=batch_name,
495+
tag_names=tag_names,
496+
sequence_number=sequence_number,
497+
sequence_size=sequence_size,
498+
**kwargs,
499+
)
500+
upload_retry_attempts = retry.retries
501+
except ImageUploadError as e:
502+
e.retries = upload_retry_attempts
503+
raise e
504+
505+
upload_time = time.time() - t0
506+
507+
return image, upload_time, upload_retry_attempts
508+
509+
def save_annotation(
510+
self,
511+
annotation_path=None,
512+
annotation_labelmap=None,
513+
image_id=None,
514+
job_name=None,
515+
is_prediction: bool = False,
516+
annotation_overwrite=False,
517+
):
518+
project_url = self.id.rsplit("/")[1]
519+
annotation_name, annotation_str = self._annotation_params(annotation_path)
520+
t0 = time.time()
521+
522+
annotation = rfapi.save_annotation(
523+
self.__api_key,
524+
project_url,
525+
annotation_name, # type: ignore[type-var]
526+
annotation_str, # type: ignore[type-var]
527+
image_id,
528+
job_name=job_name, # type: ignore[type-var]
529+
is_prediction=is_prediction,
530+
annotation_labelmap=annotation_labelmap,
531+
overwrite=annotation_overwrite,
532+
)
533+
534+
upload_time = time.time() - t0
535+
536+
return annotation, upload_time
537+
468538
def single_upload(
469539
self,
470540
image_path=None,
@@ -482,64 +552,41 @@ def single_upload(
482552
sequence_size=None,
483553
**kwargs,
484554
):
485-
project_url = self.id.rsplit("/")[1]
486555
if image_path and image_id:
487556
raise Exception("You can't pass both image_id and image_path")
488557
if not (image_path or image_id):
489558
raise Exception("You need to pass image_path or image_id")
490559
if isinstance(annotation_labelmap, str):
491560
annotation_labelmap = load_labelmap(annotation_labelmap)
561+
492562
uploaded_image, uploaded_annotation = None, None
493-
upload_time = None
563+
upload_time, annotation_time = None, None
494564
upload_retry_attempts = 0
565+
495566
if image_path:
496-
t0 = time.time()
497-
try:
498-
retry = Retry(num_retry_uploads, Exception)
499-
uploaded_image = retry(
500-
rfapi.upload_image,
501-
self.__api_key,
502-
project_url,
503-
image_path,
504-
hosted_image=hosted_image,
505-
split=split,
506-
batch_name=batch_name,
507-
tag_names=tag_names,
508-
sequence_number=sequence_number,
509-
sequence_size=sequence_size,
510-
**kwargs,
511-
)
512-
image_id = uploaded_image["id"] # type: ignore[index]
513-
upload_retry_attempts = retry.retries
514-
except rfapi.UploadError as e:
515-
raise RuntimeError(f"Error uploading image: {self._parse_upload_error(e)}")
516-
except BaseException as e:
517-
uploaded_image = {"error": e}
518-
finally:
519-
upload_time = time.time() - t0
520-
521-
annotation_time = None
567+
uploaded_image, upload_time, upload_retry_attempts = self.upload_image(
568+
image_path,
569+
hosted_image,
570+
split,
571+
num_retry_uploads,
572+
batch_name,
573+
tag_names,
574+
sequence_number,
575+
sequence_size,
576+
**kwargs,
577+
)
578+
image_id = uploaded_image["id"] # type: ignore[index]
579+
522580
if annotation_path and image_id:
523-
annotation_name, annotation_str = self._annotation_params(annotation_path)
524-
try:
525-
t0 = time.time()
526-
uploaded_annotation = rfapi.save_annotation(
527-
self.__api_key,
528-
project_url,
529-
annotation_name, # type: ignore[type-var]
530-
annotation_str, # type: ignore[type-var]
531-
image_id,
532-
job_name=batch_name, # type: ignore[type-var]
533-
is_prediction=is_prediction,
534-
annotation_labelmap=annotation_labelmap,
535-
overwrite=annotation_overwrite,
536-
)
537-
except rfapi.UploadError as e:
538-
raise RuntimeError(f"Error uploading annotation: {self._parse_upload_error(e)}")
539-
except BaseException as e:
540-
uploaded_annotation = {"error": e}
541-
finally:
542-
annotation_time = time.time() - t0
581+
uploaded_annotation, annotation_time = self.save_annotation(
582+
annotation_path,
583+
annotation_labelmap,
584+
image_id,
585+
batch_name,
586+
is_prediction,
587+
annotation_overwrite,
588+
)
589+
543590
return {
544591
"image": uploaded_image,
545592
"annotation": uploaded_annotation,
@@ -568,20 +615,6 @@ def _annotation_params(self, annotation_path):
568615
)
569616
return annotation_name, annotation_string
570617

571-
def _parse_upload_error(self, error: rfapi.UploadError) -> str:
572-
dict_part = str(error).split(": ", 2)[2]
573-
dict_part = dict_part.replace("True", "true")
574-
dict_part = dict_part.replace("False", "false")
575-
dict_part = dict_part.replace("None", "null")
576-
if re.search(r"'\w+':", dict_part):
577-
temp_str = dict_part.replace(r"\'", "<PLACEHOLDER>")
578-
temp_str = temp_str.replace('"', r"\"")
579-
temp_str = temp_str.replace("'", '"')
580-
dict_part = temp_str.replace("<PLACEHOLDER>", "'")
581-
parsed_dict: dict = json.loads(dict_part)
582-
message = parsed_dict.get("message")
583-
return message or str(parsed_dict)
584-
585618
def search(
586619
self,
587620
like_image: Optional[str] = None,

0 commit comments

Comments
 (0)