Skip to content

Commit 9076bba

Browse files
author
Brad Dwyer
authored
Merge pull request #1 from roboflow-ai/rerouting
full refactor of pip package
2 parents bb63a9d + 8530c7a commit 9076bba

File tree

11 files changed

+225
-103
lines changed

11 files changed

+225
-103
lines changed

roboflow/.DS_Store

6 KB
Binary file not shown.

roboflow/__init__.py

Lines changed: 49 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,67 +3,74 @@
33
import time
44

55
import requests
6-
6+
from roboflow.core.workspace import Workspace
77
from roboflow.core.project import Project
88
from roboflow.config import *
99

1010

11-
def auth(api_key):
11+
def check_key(api_key):
1212
if type(api_key) is not str:
1313
raise RuntimeError(
1414
"API Key is of Incorrect Type \n Expected Type: " + str(type("")) + "\n Input Type: " + str(type(api_key)))
1515

16-
response = requests.post(API_URL + "/token", data=({
17-
"api_key": api_key
18-
}))
19-
16+
response = requests.post(API_URL + "/?api_key=" + api_key)
2017
r = response.json()
18+
2119
if "error" in r or response.status_code != 200:
2220
raise RuntimeError(response.text)
21+
else:
22+
return r
23+
2324

24-
token = r['token']
25-
token_expires = r['expires_in']
26-
return Roboflow(api_key, token, token_expires)
25+
def auth(api_key):
26+
r = check_key(api_key)
27+
w = r['workspace']
28+
29+
return Roboflow(api_key, w)
2730

2831

2932
class Roboflow():
30-
def __init__(self, api_key, access_token, token_expires):
33+
def __init__(self, api_key):
3134
self.api_key = api_key
32-
self.access_token = access_token
33-
self.token_expires = token_expires
34-
# TODO: Need an endpoint to retrieve publishable key based on access token/workspace/api key
35-
publishable_key_response = requests.get(API_URL + "/key/publishable_key?access_token=" + self.access_token)
36-
if publishable_key_response.status_code != 200:
37-
raise RuntimeError(publishable_key_response.text)
38-
publishable_key_response = publishable_key_response.json()
39-
self.publishable_key = publishable_key_response['publishable_key']
40-
41-
def list_datasets(self):
42-
get_datasets_endpoint = API_URL + '/datasets'
43-
datasets = requests.get(get_datasets_endpoint + '?access_token=' + self.access_token).json()
44-
print(json.dumps(datasets, indent=2))
45-
return datasets
46-
47-
def load(self, dataset_slug):
48-
# Get info about dataset being loaded
49-
dataset_info = requests.get(API_URL + "/dataset/" + dataset_slug + "?access_token=" + self.access_token)
35+
self.auth()
36+
37+
def auth(self):
38+
r = check_key(self.api_key)
39+
w = r['workspace']
40+
41+
self.current_workspace=w
42+
43+
return self
44+
45+
def workspace(self, the_workspace=None):
46+
47+
if the_workspace is None:
48+
the_workspace = self.current_workspace
49+
50+
list_projects = requests.get(API_URL + "/" + the_workspace + '?api_key=' + self.api_key).json()
51+
52+
return Workspace(list_projects, self.api_key, the_workspace)
53+
54+
def project(self, project_name, the_workspace=None):
55+
56+
if the_workspace is None:
57+
if "/" in project_name:
58+
splitted_project = project_name.rsplit("/")
59+
the_workspace, project_name = splitted_project[0], splitted_project[1]
60+
else:
61+
the_workspace = self.current_workspace
62+
63+
dataset_info = requests.get(API_URL + "/" + the_workspace + "/" + project_name + "?api_key=" + self.api_key)
64+
5065
# Throw error if dataset isn't valid/user doesn't have permissions to access the dataset
5166
if dataset_info.status_code != 200:
5267
raise RuntimeError(dataset_info.text)
53-
# Turn dataset info into a json format otherwise
54-
dataset_info = dataset_info.json()
55-
# Get version info (i.e. version names + numbers)
56-
version_info = requests.get(
57-
API_URL + "/versions/dataset/" + dataset_slug + "?access_token=" + self.access_token)
58-
# Throw error if dataset isn't valid/user doesn't have permissions to access the dataset
59-
if version_info.status_code != 200:
60-
raise RuntimeError(version_info.text)
61-
# Turn dataset version info into a json format otherwise
62-
version_info = version_info.json()
63-
# Return a project object
64-
return Project(self.api_key, dataset_info['id'], dataset_info['type'], version_info['versions'],
65-
self.access_token, self.publishable_key)
68+
69+
dataset_info = dataset_info.json()['project']
70+
71+
return Project(self.api_key, dataset_info['id'], dataset_info['type'], dataset_info['versions'])
6672

6773
def __str__(self):
68-
json_value = {'api_key': self.api_key, 'auth_token': self.access_token, 'token_expires': self.token_expires}
74+
json_value = {'api_key': self.api_key,
75+
'workspace': self.workspace}
6976
return json.dumps(json_value, indent=2)

roboflow/config.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
OBJECT_DETECTION_MODEL = "ObjectDetectionModel"
22
CLASSIFICATION_MODEL = "ClassificationModel"
33
PREDICTION_OBJECT = "Prediction"
4-
API_URL = "http://localhost:5000"
5-
4+
API_URL = "https://api.roboflow.com"
65
from dotenv import load_dotenv
76

87
load_dotenv()

roboflow/core/model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
class Model():
2+
def __init__(self, model):
3+
self.id = model['id']
4+
5+
self.endpoint = model['endpoint']
6+
self.duration = model['end'] - model['start']
7+
self.statistics = {'recall': model['recall'], 'precision': model['precision'], 'map': model['map']}
8+
9+

roboflow/core/project.py

Lines changed: 65 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,63 +2,81 @@
22
import io
33
import json
44
import os
5-
import pathlib
65
import urllib
76
import warnings
8-
97
import cv2
108
import requests
119
from PIL import Image
12-
13-
from roboflow.models.classification import ClassificationModel
14-
from roboflow.models.object_detection import ObjectDetectionModel
1510
from roboflow.config import *
11+
from roboflow.core.version import Version
1612

17-
13+
#version class that should return
1814
class Project():
19-
def __init__(self, api_key, dataset_slug, type, versions, access_token, publishable_key):
15+
def __init__(self, api_key, dataset_slug, type, workspace):
2016
self.api_key = api_key
21-
self.dataset_slug = dataset_slug
22-
self.type = type
23-
self.access_token = access_token
24-
self.publishable_key = publishable_key
25-
# Dictionary of versions + names
26-
self.versions_and_names = versions
27-
# List of all versions to choose from
28-
self.versions = list(int(vers) for vers in versions.keys())
29-
30-
def model(self, version, local=False):
31-
# Check if version number is an available version to choose from
32-
if version not in self.versions:
33-
raise RuntimeError(
34-
str(version) + " is an invalid version; please select a different version from " + str(self.versions))
35-
36-
# Check whether model exists before initializing model
37-
model_info_response = requests.get(
38-
API_URL + "/model/" + self.dataset_slug + "/" + str(version) + "?access_token=" + self.access_token)
39-
if model_info_response.status_code != 200:
40-
raise RuntimeError(model_info_response.text)
41-
42-
model_info_response = model_info_response.json()
43-
# Return appropriate model if model does exist
44-
if model_info_response['exists']:
45-
if self.type == "object-detection":
46-
return ObjectDetectionModel(self.api_key, self.dataset_slug, version, local=local)
47-
elif self.type == "classification":
48-
return ClassificationModel(self.api_key, self.dataset_slug, version, local=local)
17+
self.name = dataset_slug
18+
self.category = type
19+
self.workspace = workspace
20+
self.all_versions = []
21+
22+
def get_version_information(self):
23+
24+
slug_splitted = self.name.rsplit("/")
25+
p, w = slug_splitted[0], slug_splitted[1]
26+
27+
dataset_info = requests.get(API_URL + "/" + p + "/" + w + "?api_key=" + self.api_key)
28+
29+
# Throw error if dataset isn't valid/user doesn't have permissions to access the dataset
30+
if dataset_info.status_code != 200:
31+
raise RuntimeError(dataset_info.text)
32+
33+
dataset_info = dataset_info.json()['project']
34+
return dataset_info['versions']
35+
36+
def list_versions(self):
37+
version_info = self.get_version_information()
38+
print(version_info)
39+
40+
def versions(self):
41+
version_info = self.get_version_information()
42+
version_array = []
43+
for a_version in version_info:
44+
version_object = Version((self.category if 'model' in a_version else None), self.api_key, self.name, a_version['id'], local=False)
45+
version_array.append(version_object)
46+
47+
return version_array
48+
49+
def version(self, version_number):
50+
51+
version_info = self.get_version_information()
52+
53+
for version_object in version_info:
54+
55+
current_version_num = os.path.basename(version_object['id'])
56+
if current_version_num == version_number:
57+
vers = Version(self.category, self.api_key, self.name, current_version_num, local=False)
58+
return vers
59+
60+
raise RuntimeError("Version number {} is not found.".format(version_number))
4961

5062
def __image_upload(self, image_path, hosted_image=False, split="train"):
63+
5164
# If image is not a hosted image
5265
if not hosted_image:
66+
project_name = os.path.basename(self.name)
67+
image_name = os.path.basename(image_path)
5368
# Construct URL for local image upload
5469
self.image_upload_url = "".join([
55-
"https://api.roboflow.com/dataset/", self.dataset_slug, "/upload",
70+
"https://api.roboflow.com/dataset/", project_name, "/upload",
5671
"?api_key=", self.api_key,
57-
"&name=" + os.path.basename(image_path),
72+
"&name=" + image_name,
5873
"&split=" + split
5974
])
75+
6076
# Convert to PIL Image
61-
image = cv2.cvtColor(image_path, cv2.COLOR_BGR2RGB)
77+
img = cv2.imread(image_path)
78+
image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
79+
6280
pilImage = Image.fromarray(image)
6381
# Convert to JPEG Buffer
6482
buffered = io.BytesIO()
@@ -67,14 +85,16 @@ def __image_upload(self, image_path, hosted_image=False, split="train"):
6785
img_str = base64.b64encode(buffered.getvalue())
6886
img_str = img_str.decode("ascii")
6987
# Post Base 64 Data to Upload API
88+
7089
response = requests.post(self.image_upload_url, data=img_str, headers={
7190
"Content-Type": "application/x-www-form-urlencoded"
7291
})
7392

93+
7494
else:
7595
# Hosted image upload url
7696
upload_url = "".join([
77-
"https://api.roboflow.com/dataset/" + self.dataset_slug + "/upload",
97+
"https://api.roboflow.com/dataset/" + self.name + "/upload",
7898
"?api_key=" + self.api_key,
7999
"&name=" + os.path.basename(image_path),
80100
"&split=" + split,
@@ -83,14 +103,15 @@ def __image_upload(self, image_path, hosted_image=False, split="train"):
83103
# Get response
84104
response = requests.post(upload_url)
85105
# Return response
106+
86107
return response
87108

88109
def __annotation_upload(self, annotation_path, image_id):
89110
# Get annotation string
90111
annotation_string = open(annotation_path, "r").read()
91112
# Set annotation upload url
92113
self.annotation_upload_url = "".join([
93-
"https://api.roboflow.com/dataset/", self.dataset_slug, "/annotate/", image_id,
114+
"https://api.roboflow.com/dataset/", self.name, "/annotate/", image_id,
94115
"?api_key=", self.api_key,
95116
"&name=" + os.path.basename(annotation_path)
96117
])
@@ -107,6 +128,7 @@ def upload(self, image_path=None, annotation_path=None, hosted_image=False, imag
107128
if image_path is not None:
108129
# Upload Image Response
109130
response = self.__image_upload(image_path, hosted_image=hosted_image, split=split)
131+
110132
# Get JSON response values
111133
try:
112134
success, image_id = response.json()['success'], response.json()['id']
@@ -144,9 +166,9 @@ def upload(self, image_path=None, annotation_path=None, hosted_image=False, imag
144166
def __str__(self):
145167
# String representation of project
146168
json_str = {
147-
"dataset_slug": self.dataset_slug,
148-
"dataset_type": self.type,
149-
"dataset_versions": self.versions_and_names
169+
"dataset_slug": self.name,
170+
"task_type": self.category,
171+
"workspace": self.workspace
150172
}
151173

152174
return json.dumps(json_str, indent=2)

roboflow/core/version.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from roboflow.models.classification import ClassificationModel
2+
from roboflow.models.object_detection import ObjectDetectionModel
3+
import os
4+
import json
5+
6+
7+
class Version():
8+
def __init__(self, type, api_key, dataset_slug, version, local):
9+
self.api_key = api_key
10+
self.name = dataset_slug
11+
self.version = version
12+
self.category = type
13+
14+
version_without_workspace = os.path.basename(version)
15+
16+
if self.category == "object-detection":
17+
self.model = ObjectDetectionModel(self.api_key, self.name, version_without_workspace, local=local)
18+
elif self.category == "classification":
19+
self.model = ClassificationModel(self.api_key, self.name, version_without_workspace, local=local)
20+
else:
21+
self.model = None
22+
23+
def __str__(self):
24+
json_value = {'api_key': self.api_key,
25+
'name': self.name,
26+
'model_type': str(self.model),
27+
'version': self.version}
28+
return json.dumps(json_value, indent=2)
29+
30+
31+

roboflow/core/workspace.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import requests
2+
from roboflow.core.project import Project
3+
from roboflow.config import *
4+
5+
class Workspace():
6+
def __init__(self, info, api_key, default_workspace):
7+
8+
self.api_key = api_key
9+
self.name = default_workspace
10+
11+
workspace_info = info['workspace']
12+
self.members = workspace_info['members']
13+
self.url = workspace_info['url']
14+
self.project_list = []
15+
16+
for value in info['workspace']['projects']:
17+
self.project_list.append(value)
18+
19+
def list_projects(self):
20+
print(self.projects)
21+
22+
def projects(self):
23+
projects_array = []
24+
for a_project in self.project_list:
25+
split = a_project['id'].rsplit("/")
26+
workspace, project_name = split[0], split[1]
27+
proj = Project(self.api_key, project_name, a_project['type'], workspace)
28+
projects_array.append(proj)
29+
30+
return projects_array
31+
32+
33+
def project(self, project_name):
34+
35+
if "/" in project_name:
36+
raise RuntimeError("Do not re-specify the workspace {} in your project request".format(project_name.rsplit()[0]))
37+
38+
dataset_info = requests.get(API_URL + "/" + self.name + "/" + project_name + "?api_key=" + self.api_key)
39+
40+
# Throw error if dataset isn't valid/user doesn't have permissions to access the dataset
41+
if dataset_info.status_code != 200:
42+
raise RuntimeError(dataset_info.text)
43+
44+
dataset_info = dataset_info.json()['project']
45+
46+
return Project(self.api_key, dataset_info['id'], dataset_info['type'], dataset_info['versions'])

0 commit comments

Comments
 (0)