Skip to content

Commit a408242

Browse files
author
Francesco Saverio Zuppichini
authored
removed check for extensions, PIL will do that for us (#86)
* removed check for extensions, PIL will do that for us * removed unused line * added cv2 for arrays * added check for image_path type * make style * added tests for object detection * added tests for object detection * not image file in tests
1 parent bd07c2d commit a408242

File tree

12 files changed

+353
-122
lines changed

12 files changed

+353
-122
lines changed

roboflow/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
SEMANTIC_SEGMENTATION_URL = os.getenv(
1515
"SEMANTIC_SEGMENTATION_URL", "https://segment.roboflow.com"
1616
)
17+
OBJECT_DETECTION_URL = os.getenv(
18+
"SEMANTIC_SEGMENTATION_URL", "https://detect.roboflow.com"
19+
)
1720

1821
CLIP_FEATURIZE_URL = os.getenv("CLIP_FEATURIZE_URL", "CLIP FEATURIZE URL NOT IN ENV")
1922
OCR_URL = os.getenv("OCR_URL", "OCR URL NOT IN ENV")

roboflow/models/object_detection.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import json
44
import os
55
import urllib
6+
from pathlib import Path
67

78
import cv2
9+
import numpy as np
810
import requests
911
from PIL import Image
1012

@@ -142,23 +144,21 @@ def predict(
142144

143145
# If image is local image
144146
if not hosted:
145-
if ".jpg" in image_path or ".png" in image_path: # Open Image in RGB Format
147+
if type(image_path) is str:
146148
image = Image.open(image_path).convert("RGB")
147-
148149
# Create buffer
149150
buffered = io.BytesIO()
150151
image.save(buffered, format="PNG")
151152
# Base64 encode image
152153
img_str = base64.b64encode(buffered.getvalue())
153154
img_str = img_str.decode("ascii")
154-
155155
# Post to API and return response
156156
resp = requests.post(
157157
self.api_url,
158158
data=img_str,
159159
headers={"Content-Type": "application/x-www-form-urlencoded"},
160160
)
161-
else:
161+
elif isinstance(image_path, np.ndarray):
162162
# Performing inference on a OpenCV2 frame
163163
retval, buffer = cv2.imencode(".jpg", image_path)
164164
img_str = base64.b64encode(buffer)
@@ -169,15 +169,15 @@ def predict(
169169
data=img_str,
170170
headers={"Content-Type": "application/x-www-form-urlencoded"},
171171
)
172-
172+
else:
173+
raise ValueError("image_path must be a string or a numpy array.")
173174
else:
174175
# Create API URL for hosted image (slightly different)
175176
self.api_url += "&image=" + urllib.parse.quote_plus(image_path)
176177
# POST to the API
177178
resp = requests.get(self.api_url)
178179

179-
if resp.status_code != 200:
180-
raise Exception(resp.text)
180+
resp.raise_for_status()
181181
# Return a prediction group if JSON data
182182
if self.format == "json":
183183
return PredictionGroup.create_prediction_group(

roboflow/util/clip_compare_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import base64
2-
import glob
32
import io
43
import json
54

tests/__init__.py

Lines changed: 93 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def compare(a, b):
3333

3434

3535
class RoboflowTest(unittest.TestCase):
36-
3736
def setUp(self):
3837
super(RoboflowTest, self).setUp()
3938
responses.start()
@@ -46,55 +45,125 @@ def setUp(self):
4645
"welcome": "Welcome to the Roboflow API.",
4746
"instructions": "You are successfully authenticated.",
4847
"docs": "https://docs.roboflow.com",
49-
"workspace": WORKSPACE_NAME
48+
"workspace": WORKSPACE_NAME,
5049
},
51-
status=200
50+
status=200,
5251
)
5352

5453
# Get workspace
5554
responses.add(
5655
responses.GET,
5756
f"{API_URL}/{WORKSPACE_NAME}?api_key={ROBOFLOW_API_KEY}",
5857
json={
59-
'workspace': {
60-
'name': WORKSPACE_NAME,
61-
'url': WORKSPACE_NAME,
62-
'members': 1,
63-
'projects': [
64-
{'id': f'{WORKSPACE_NAME}/{PROJECT_NAME}', 'type': 'object-detection', 'name': 'Hard Hat Sample', 'created': 1593802673.521, 'updated': 1663269501.654, 'images': 100, 'unannotated': 3, 'annotation': 'Workers', 'versions': 2, 'public': False, 'splits': {'train': 70, 'test': 10, 'valid': 20}, 'colors': {'head': '#8622FF', 'person': '#FF00FF', 'helmet': '#C7FC00'}, 'classes': {'person': 9, 'helmet': 287, 'head': 90}}
65-
]
58+
"workspace": {
59+
"name": WORKSPACE_NAME,
60+
"url": WORKSPACE_NAME,
61+
"members": 1,
62+
"projects": [
63+
{
64+
"id": f"{WORKSPACE_NAME}/{PROJECT_NAME}",
65+
"type": "object-detection",
66+
"name": "Hard Hat Sample",
67+
"created": 1593802673.521,
68+
"updated": 1663269501.654,
69+
"images": 100,
70+
"unannotated": 3,
71+
"annotation": "Workers",
72+
"versions": 2,
73+
"public": False,
74+
"splits": {"train": 70, "test": 10, "valid": 20},
75+
"colors": {
76+
"head": "#8622FF",
77+
"person": "#FF00FF",
78+
"helmet": "#C7FC00",
79+
},
80+
"classes": {"person": 9, "helmet": 287, "head": 90},
81+
}
82+
],
6683
}
6784
},
68-
status=200
85+
status=200,
6986
)
7087

7188
# Get project
7289
responses.add(
7390
responses.GET,
7491
f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}?api_key={ROBOFLOW_API_KEY}",
7592
json={
76-
'workspace': {
77-
'name': WORKSPACE_NAME,
78-
'url': WORKSPACE_NAME,
79-
'members': 1
93+
"workspace": {
94+
"name": WORKSPACE_NAME,
95+
"url": WORKSPACE_NAME,
96+
"members": 1,
8097
},
81-
'project': {
82-
'id': f'{WORKSPACE_NAME}/{PROJECT_NAME}', 'type': 'object-detection', 'name': 'Hard Hat Sample', 'created': 1593802673.521, 'updated': 1663269501.654, 'images': 100, 'unannotated': 3, 'annotation': 'Workers', 'versions': 2, 'public': False, 'splits': {'test': 10, 'train': 70, 'valid': 20}, 'colors': {'person': '#FF00FF', 'helmet': '#C7FC00', 'head': '#8622FF'}, 'classes': {'person': 9, 'helmet': 287, 'head': 90}
98+
"project": {
99+
"id": f"{WORKSPACE_NAME}/{PROJECT_NAME}",
100+
"type": "object-detection",
101+
"name": "Hard Hat Sample",
102+
"created": 1593802673.521,
103+
"updated": 1663269501.654,
104+
"images": 100,
105+
"unannotated": 3,
106+
"annotation": "Workers",
107+
"versions": 2,
108+
"public": False,
109+
"splits": {"test": 10, "train": 70, "valid": 20},
110+
"colors": {
111+
"person": "#FF00FF",
112+
"helmet": "#C7FC00",
113+
"head": "#8622FF",
114+
},
115+
"classes": {"person": 9, "helmet": 287, "head": 90},
83116
},
84-
'versions': [
85-
{'id': f'{WORKSPACE_NAME}/{PROJECT_NAME}/2', 'name': 'augmented-416x416', 'created': 1663104679.539, 'images': 240, 'splits': {'train': 210, 'test': 10, 'valid': 20}, 'preprocessing': {'resize': {'height': '416', 'enabled': True, 'width': '416', 'format': 'Stretch to'}, 'auto-orient': {'enabled': True}}, 'augmentation': {'blur': {'enabled': True, 'pixels': 1.5}, 'image': {'enabled': True, 'versions': 3}, 'rotate': {'degrees': 15, 'enabled': True}, 'crop': {'enabled': True, 'percent': 40, 'min': 0}, 'flip': {'horizontal': True, 'enabled': True, 'vertical': False}}, 'exports': []},
86-
{'id': f'{WORKSPACE_NAME}/{PROJECT_NAME}/1', 'name': 'raw', 'created': 1663104679.538, 'images': 100, 'splits': {'train': 70, 'test': 10, 'valid': 20}, 'preprocessing': {}, 'augmentation': {}, 'exports': []}
87-
]
117+
"versions": [
118+
{
119+
"id": f"{WORKSPACE_NAME}/{PROJECT_NAME}/2",
120+
"name": "augmented-416x416",
121+
"created": 1663104679.539,
122+
"images": 240,
123+
"splits": {"train": 210, "test": 10, "valid": 20},
124+
"preprocessing": {
125+
"resize": {
126+
"height": "416",
127+
"enabled": True,
128+
"width": "416",
129+
"format": "Stretch to",
130+
},
131+
"auto-orient": {"enabled": True},
132+
},
133+
"augmentation": {
134+
"blur": {"enabled": True, "pixels": 1.5},
135+
"image": {"enabled": True, "versions": 3},
136+
"rotate": {"degrees": 15, "enabled": True},
137+
"crop": {"enabled": True, "percent": 40, "min": 0},
138+
"flip": {
139+
"horizontal": True,
140+
"enabled": True,
141+
"vertical": False,
142+
},
143+
},
144+
"exports": [],
145+
},
146+
{
147+
"id": f"{WORKSPACE_NAME}/{PROJECT_NAME}/1",
148+
"name": "raw",
149+
"created": 1663104679.538,
150+
"images": 100,
151+
"splits": {"train": 70, "test": 10, "valid": 20},
152+
"preprocessing": {},
153+
"augmentation": {},
154+
"exports": [],
155+
},
156+
],
88157
},
89-
status=200
158+
status=200,
90159
)
91160

92161
# Upload image
93162
responses.add(
94163
responses.POST,
95164
f"{API_URL}/dataset/{PROJECT_NAME}/upload?api_key={ROBOFLOW_API_KEY}&batch={DEFAULT_BATCH_NAME}",
96-
json={'duplicate': True, 'id': 'hbALkCFdNr9rssgOUXug'},
97-
status=200
165+
json={"duplicate": True, "id": "hbALkCFdNr9rssgOUXug"},
166+
status=200,
98167
)
99168

100169
self.connect_to_roboflow()

tests/images/not_an_image.txt

Whitespace-only changes.

tests/models/test_instance_segmentation.py

Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,10 @@
1818
"class": "J",
1919
"confidence": 0.598,
2020
"points": [
21-
{
22-
"x": 831.0,
23-
"y": 527.0
24-
},
25-
{
26-
"x": 931.0,
27-
"y": 389.0
28-
},
29-
{
30-
"x": 831.0,
31-
"y": 527.0
32-
}
33-
]
21+
{"x": 831.0, "y": 527.0},
22+
{"x": 931.0, "y": 389.0},
23+
{"x": 831.0, "y": 527.0},
24+
],
3425
},
3526
{
3627
"x": 363.8,
@@ -40,26 +31,13 @@
4031
"class": "K",
4132
"confidence": 0.52,
4233
"points": [
43-
{
44-
"x": 131.0,
45-
"y": 999.0
46-
},
47-
{
48-
"x": 269.0,
49-
"y": 666.0
50-
},
51-
52-
{
53-
"x": 131.0,
54-
"y": 999.0
55-
}
56-
]
57-
}
34+
{"x": 131.0, "y": 999.0},
35+
{"x": 269.0, "y": 666.0},
36+
{"x": 131.0, "y": 999.0},
37+
],
38+
},
5839
],
59-
"image": {
60-
"width": 1333,
61-
"height": 1000
62-
}
40+
"image": {"width": 1333, "height": 1000},
6341
}
6442

6543

@@ -85,7 +63,10 @@ def test_init_sets_attributes(self):
8563
instance = InstanceSegmentationModel(self.api_key, self.version_id)
8664

8765
self.assertEqual(instance.id, self.version_id)
88-
self.assertEqual(instance.api_url, f"{INSTANCE_SEGMENTATION_URL}/{self.dataset_id}/{self.version}")
66+
self.assertEqual(
67+
instance.api_url,
68+
f"{INSTANCE_SEGMENTATION_URL}/{self.dataset_id}/{self.version}",
69+
)
8970

9071
@responses.activate
9172
def test_predict_returns_prediction_group(self):
@@ -109,7 +90,7 @@ def test_predict_with_local_image_request(self):
10990

11091
request = responses.calls[0].request
11192

112-
self.assertEqual(request.method, 'POST')
93+
self.assertEqual(request.method, "POST")
11394
self.assertRegex(request.url, rf"^{self.api_url}")
11495
self.assertDictEqual(request.params, self._default_params)
11596
self.assertIsNotNone(request.body)
@@ -131,7 +112,7 @@ def test_predict_with_hosted_image_request(self):
131112

132113
request = responses.calls[1].request
133114

134-
self.assertEqual(request.method, 'POST')
115+
self.assertEqual(request.method, "POST")
135116
self.assertRegex(request.url, rf"^{self.api_url}")
136117
self.assertDictEqual(request.params, expected_params)
137118
self.assertIsNone(request.body)
@@ -140,10 +121,7 @@ def test_predict_with_hosted_image_request(self):
140121
def test_predict_with_confidence_request(self):
141122
confidence = "100"
142123
image_path = "tests/images/rabbit.JPG"
143-
expected_params = {
144-
**self._default_params,
145-
"confidence": confidence
146-
}
124+
expected_params = {**self._default_params, "confidence": confidence}
147125
instance = InstanceSegmentationModel(self.api_key, self.version_id)
148126

149127
responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)
@@ -152,7 +130,7 @@ def test_predict_with_confidence_request(self):
152130

153131
request = responses.calls[0].request
154132

155-
self.assertEqual(request.method, 'POST')
133+
self.assertEqual(request.method, "POST")
156134
self.assertRegex(request.url, rf"^{self.api_url}")
157135
self.assertDictEqual(request.params, expected_params)
158136
self.assertIsNotNone(request.body)

0 commit comments

Comments
 (0)