Skip to content

Commit 96ce2a0

Browse files
Merge pull request #93 from roboflow/json-response-trt-fix
Fix for json_response["image"] in prediction.py
2 parents b5495c1 + 7d0fd36 commit 96ce2a0

File tree

6 files changed

+31
-40
lines changed

6 files changed

+31
-40
lines changed

roboflow/core/project.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -534,10 +534,6 @@ def single_upload(
534534

535535
def __str__(self):
536536
# String representation of project
537-
json_str = {
538-
"name": self.name,
539-
"type": self.type,
540-
"workspace": self.__workspace,
541-
}
537+
json_str = {"name": self.name, "type": self.type, "workspace": self.__workspace}
542538

543539
return json.dumps(json_str, indent=2)

roboflow/core/version.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,9 @@ def __init__(
9898
local=local,
9999
)
100100
elif self.type == TYPE_INSTANCE_SEGMENTATION:
101-
self.model = InstanceSegmentationModel(
102-
self.__api_key,
103-
self.id,
104-
)
101+
self.model = InstanceSegmentationModel(self.__api_key, self.id)
105102
elif self.type == TYPE_SEMANTIC_SEGMENTATION:
106-
self.model = SemanticSegmentationModel(
107-
self.__api_key,
108-
self.id,
109-
)
103+
self.model = SemanticSegmentationModel(self.__api_key, self.id)
110104
else:
111105
self.model = None
112106

@@ -485,10 +479,7 @@ def __get_format_identifier(self, format):
485479
"You must pass a format argument to version.download() or define a model in your Roboflow object"
486480
)
487481

488-
friendly_formats = {
489-
"yolov5": "yolov5pytorch",
490-
"yolov7": "yolov7pytorch",
491-
}
482+
friendly_formats = {"yolov5": "yolov5pytorch", "yolov7": "yolov7pytorch"}
492483
return friendly_formats.get(format, format)
493484

494485
def __reformat_yaml(self, location, format):

roboflow/models/inference.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,25 @@ def __get_image_params(self, image_path):
3131
"""
3232
validate_image_path(image_path)
3333

34-
hosted_image = urllib.parse.urlparse(image_path).scheme in (
35-
"http",
36-
"https",
37-
)
34+
hosted_image = urllib.parse.urlparse(image_path).scheme in ("http", "https")
3835

3936
if hosted_image:
40-
return {"image": image_path}, {}
37+
image_dims = {"width": "Undefined", "height": "Undefined"}
38+
return {"image": image_path}, {}, image_dims
4139

4240
image = Image.open(image_path)
41+
dimensions = image.size
42+
image_dims = {"width": str(dimensions[0]), "height": str(dimensions[1])}
4343
buffered = io.BytesIO()
4444
image.save(buffered, quality=90, format="JPEG")
4545
data = MultipartEncoder(
4646
fields={"file": ("imageToUpload", buffered.getvalue(), "image/jpeg")}
4747
)
48-
return {}, {
49-
"data": data,
50-
"headers": {"Content-Type": data.content_type},
51-
}
48+
return (
49+
{},
50+
{"data": data, "headers": {"Content-Type": data.content_type}},
51+
image_dims,
52+
)
5253

5354
def predict(self, image_path, prediction_type=None, **kwargs):
5455
"""
@@ -60,7 +61,7 @@ def predict(self, image_path, prediction_type=None, **kwargs):
6061
:return: PredictionGroup - a group of predictions based on Roboflow JSON response
6162
:raises Exception: Image path is not valid
6263
"""
63-
params, request_kwargs = self.__get_image_params(image_path)
64+
params, request_kwargs, image_dims = self.__get_image_params(image_path)
6465

6566
params["api_key"] = self.__api_key
6667

@@ -74,4 +75,5 @@ def predict(self, image_path, prediction_type=None, **kwargs):
7475
response.json(),
7576
image_path=image_path,
7677
prediction_type=prediction_type,
78+
image_dims=image_dims,
7779
)

roboflow/models/object_detection.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def predict(
146146
if not hosted:
147147
if type(image_path) is str:
148148
image = Image.open(image_path).convert("RGB")
149+
dimensions = image.size
149150
# Create buffer
150151
buffered = io.BytesIO()
151152
image.save(buffered, format="PNG")
@@ -158,22 +159,27 @@ def predict(
158159
data=img_str,
159160
headers={"Content-Type": "application/x-www-form-urlencoded"},
160161
)
162+
image_dims = {"width": str(dimensions[0]), "height": str(dimensions[1])}
161163
elif isinstance(image_path, np.ndarray):
162164
# Performing inference on a OpenCV2 frame
163165
retval, buffer = cv2.imencode(".jpg", image_path)
166+
# Currently cv2.imencode does not properly return shape
167+
dimensions = buffer.shape
164168
img_str = base64.b64encode(buffer)
165-
# print(img_str)
166169
img_str = img_str.decode("ascii")
167170
resp = requests.post(
168171
self.api_url,
169172
data=img_str,
170173
headers={"Content-Type": "application/x-www-form-urlencoded"},
171174
)
175+
# Replace with dimensions variable once cv2.imencode shape solution is found
176+
image_dims = {"width": "0", "height": "0"}
172177
else:
173178
raise ValueError("image_path must be a string or a numpy array.")
174179
else:
175180
# Create API URL for hosted image (slightly different)
176181
self.api_url += "&image=" + urllib.parse.quote_plus(image_path)
182+
image_dims = {"width": "0", "height": "0"}
177183
# POST to the API
178184
resp = requests.get(self.api_url)
179185

@@ -184,6 +190,7 @@ def predict(
184190
resp.json(),
185191
image_path=image_path,
186192
prediction_type=OBJECT_DETECTION_MODEL,
193+
image_dims=image_dims,
187194
)
188195
# Returns base64 encoded Data
189196
elif self.format == "image":

roboflow/util/image_utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@ def check_image_url(url):
2222
:param url: URL of image
2323
:returns: Boolean
2424
"""
25-
if urllib.parse.urlparse(url).scheme not in (
26-
"http",
27-
"https",
28-
):
25+
if urllib.parse.urlparse(url).scheme not in ("http", "https"):
2926
return False
3027

3128
r = requests.head(url)

roboflow/util/prediction.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,7 @@ def plot_annotation(axes, prediction=None, stroke=1, transparency=60):
8080
elif prediction["prediction_type"] == INSTANCE_SEGMENTATION_MODEL:
8181
points = [[p["x"], p["y"]] for p in prediction["points"]]
8282
polygon = patches.Polygon(
83-
points,
84-
linewidth=stroke,
85-
edgecolor="r",
86-
facecolor="none",
83+
points, linewidth=stroke, edgecolor="r", facecolor="none"
8784
)
8885
axes.add_patch(polygon)
8986
elif prediction["prediction_type"] == SEMANTIC_SEGMENTATION_MODEL:
@@ -476,14 +473,15 @@ def json(self):
476473
return prediction_group_json
477474

478475
@staticmethod
479-
def create_prediction_group(json_response, image_path, prediction_type):
476+
def create_prediction_group(json_response, image_path, prediction_type, image_dims):
480477
"""
481478
Method to create a prediction group based on the JSON Response
482479
483480
:param prediction_type:
484481
:param json_response: Based on Roboflow JSON Response from Inference API
485482
:param model:
486483
:param image_path:
484+
:param image_dims:
487485
:return:
488486
"""
489487
prediction_list = []
@@ -494,15 +492,15 @@ def create_prediction_group(json_response, image_path, prediction_type):
494492
prediction, image_path, prediction_type=prediction_type
495493
)
496494
prediction_list.append(prediction)
497-
img_dims = json_response["image"]
495+
img_dims = image_dims
498496
elif prediction_type == CLASSIFICATION_MODEL:
499497
prediction = Prediction(json_response, image_path, prediction_type)
500498
prediction_list.append(prediction)
501-
img_dims = {}
499+
img_dims = image_dims
502500
elif prediction_type == SEMANTIC_SEGMENTATION_MODEL:
503501
prediction = Prediction(json_response, image_path, prediction_type)
504502
prediction_list.append(prediction)
505-
img_dims = json_response["image"]
503+
img_dims = image_dims
506504

507505
# Seperate list and return as a prediction group
508506
return PredictionGroup(img_dims, image_path, *prediction_list)

0 commit comments

Comments
 (0)