Skip to content

Commit f20526e

Browse files
author
Brad Dwyer
authored
Merge pull request #2 from roboflow-ai/final-changes
final changes to pip package
2 parents 2202e5c + f64e544 commit f20526e

File tree

7 files changed

+136
-78
lines changed

7 files changed

+136
-78
lines changed

requirements.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ chardet==4.0.0
33
cycler==0.10.0
44
idna==2.10
55
kiwisolver==1.3.1
6-
matplotlib==3.3.4
6+
matplotlib
77
numpy==1.19.5
88
opencv-python==4.5.3.56
9-
Pillow==8.3.0
9+
Pillow
1010
pyparsing==2.4.7
11-
python-dateutil==2.8.1
11+
python-dateutil
1212
python-dotenv==0.18.0
1313
requests==2.25.1
14-
six==1.16.0
14+
six
1515
urllib3==1.26.6

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def project(self, project_name, the_workspace=None):
6868

6969
dataset_info = dataset_info.json()['project']
7070

71-
return Project(self.api_key, dataset_info['id'], dataset_info['type'], dataset_info['versions'])
71+
return Project(self.api_key, dataset_info)
7272

7373
def __str__(self):
7474
json_value = {'api_key': self.api_key,

roboflow/core/project.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import os
55
import urllib
6+
import datetime
67
import warnings
78
import cv2
89
import requests
@@ -12,25 +13,33 @@
1213

1314
#version class that should return
1415
class Project():
15-
def __init__(self, api_key, dataset_slug, type, workspace):
16-
self.api_key = api_key
17-
self.name = dataset_slug
18-
self.category = type
19-
self.workspace = workspace
20-
self.all_versions = []
16+
def __init__(self, api_key, a_project):
17+
self.__api_key = api_key
18+
self.annotation = a_project['annotation']
19+
self.classes = a_project['classes']
20+
self.colors = a_project['colors']
21+
self.created = datetime.datetime.fromtimestamp(a_project['created'])
22+
self.id = a_project['id']
23+
self.images = a_project['images']
24+
self.name = a_project['name']
25+
self.public = a_project['public']
26+
self.splits = a_project['splits']
27+
self.type = a_project['type']
28+
self.unannotated = a_project['unannotated']
29+
self.updated = datetime.datetime.fromtimestamp(a_project['updated'])
30+
31+
temp = self.id.rsplit("/")
32+
self.__workspace = temp[0]
33+
self.__project_name = temp[1]
2134

2235
def get_version_information(self):
23-
24-
slug_splitted = self.name.rsplit("/")
25-
p, w = slug_splitted[0], slug_splitted[1]
26-
27-
dataset_info = requests.get(API_URL + "/" + p + "/" + w + "?api_key=" + self.api_key)
36+
dataset_info = requests.get(API_URL + "/" + self.__workspace + "/" + self.__project_name + "?api_key=" + self.__api_key)
2837

2938
# Throw error if dataset isn't valid/user doesn't have permissions to access the dataset
3039
if dataset_info.status_code != 200:
3140
raise RuntimeError(dataset_info.text)
3241

33-
dataset_info = dataset_info.json()['project']
42+
dataset_info = dataset_info.json()
3443
return dataset_info['versions']
3544

3645
def list_versions(self):
@@ -41,9 +50,8 @@ def versions(self):
4150
version_info = self.get_version_information()
4251
version_array = []
4352
for a_version in version_info:
44-
version_object = Version((self.category if 'model' in a_version else None), self.api_key, self.name, a_version['id'], local=False)
53+
version_object = Version(a_version, (self.type if 'model' in a_version else None), self.__api_key, self.name, a_version['id'], local=False)
4554
version_array.append(version_object)
46-
4755
return version_array
4856

4957
def version(self, version_number):
@@ -53,8 +61,8 @@ def version(self, version_number):
5361
for version_object in version_info:
5462

5563
current_version_num = os.path.basename(version_object['id'])
56-
if current_version_num == version_number:
57-
vers = Version(self.category, self.api_key, self.name, current_version_num, local=False)
64+
if current_version_num == str(version_number):
65+
vers = Version(version_object, self.type, self.__api_key, self.name, current_version_num, local=False)
5866
return vers
5967

6068
raise RuntimeError("Version number {} is not found.".format(version_number))
@@ -63,12 +71,15 @@ def __image_upload(self, image_path, hosted_image=False, split="train"):
6371

6472
# If image is not a hosted image
6573
if not hosted_image:
66-
project_name = os.path.basename(self.name)
74+
75+
project_name = self.id.rsplit("/")[1]
6776
image_name = os.path.basename(image_path)
77+
6878
# Construct URL for local image upload
79+
6980
self.image_upload_url = "".join([
7081
"https://api.roboflow.com/dataset/", project_name, "/upload",
71-
"?api_key=", self.api_key,
82+
"?api_key=", self.__api_key,
7283
"&name=" + image_name,
7384
"&split=" + split
7485
])
@@ -93,9 +104,11 @@ def __image_upload(self, image_path, hosted_image=False, split="train"):
93104

94105
else:
95106
# Hosted image upload url
107+
project_name = self.id.rsplit("/")[1]
108+
96109
upload_url = "".join([
97-
"https://api.roboflow.com/dataset/" + self.name + "/upload",
98-
"?api_key=" + self.api_key,
110+
"https://api.roboflow.com/dataset/" + self.project_name + "/upload",
111+
"?api_key=" + self.__api_key,
99112
"&name=" + os.path.basename(image_path),
100113
"&split=" + split,
101114
"&image=" + urllib.parse.quote_plus(image_path)
@@ -112,7 +125,7 @@ def __annotation_upload(self, annotation_path, image_id):
112125
# Set annotation upload url
113126
self.annotation_upload_url = "".join([
114127
"https://api.roboflow.com/dataset/", self.name, "/annotate/", image_id,
115-
"?api_key=", self.api_key,
128+
"?api_key=", self.__api_key,
116129
"&name=" + os.path.basename(annotation_path)
117130
])
118131
# Get annotation response
@@ -128,7 +141,6 @@ def upload(self, image_path=None, annotation_path=None, hosted_image=False, imag
128141
if image_path is not None:
129142
# Upload Image Response
130143
response = self.__image_upload(image_path, hosted_image=hosted_image, split=split)
131-
132144
# Get JSON response values
133145
try:
134146
success, image_id = response.json()['success'], response.json()['id']
@@ -166,9 +178,9 @@ def upload(self, image_path=None, annotation_path=None, hosted_image=False, imag
166178
def __str__(self):
167179
# String representation of project
168180
json_str = {
169-
"dataset_slug": self.name,
170-
"task_type": self.category,
171-
"workspace": self.workspace
181+
"name": self.name,
182+
"type": self.type,
183+
"workspace": self.__workspace,
172184
}
173185

174186
return json.dumps(json_str, indent=2)

roboflow/core/version.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,34 @@
55

66

77
class Version():
8-
def __init__(self, type, api_key, dataset_slug, version, local):
9-
self.api_key = api_key
10-
self.name = dataset_slug
8+
def __init__(self, version_dict, type, api_key, name, version, local):
9+
self.__api_key = api_key
10+
self.name = name
1111
self.version = version
12-
self.category = type
12+
self.type = type
13+
self.augmentation = version_dict['augmentation']
14+
self.created = version_dict['created']
15+
self.id = version_dict['id']
16+
self.images = version_dict['images']
17+
self.preprocessing = version_dict['preprocessing']
18+
self.splits = version_dict['splits']
1319

1420
version_without_workspace = os.path.basename(version)
1521

16-
if self.category == "object-detection":
17-
self.model = ObjectDetectionModel(self.api_key, self.name, version_without_workspace, local=local)
18-
elif self.category == "classification":
19-
self.model = ClassificationModel(self.api_key, self.name, version_without_workspace, local=local)
22+
if self.type == "object-detection":
23+
self.model = ObjectDetectionModel(self.__api_key, self.id, self.name, version_without_workspace, local=local)
24+
elif self.type == "classification":
25+
self.model = ClassificationModel(self.__api_key, self.id, self.name, version_without_workspace, self.id, local=local)
2026
else:
2127
self.model = None
2228

2329
def __str__(self):
24-
json_value = {'api_key': self.api_key,
25-
'name': self.name,
26-
'model_type': str(self.model),
27-
'version': self.version}
30+
json_value = {
31+
'name': self.name,
32+
'type': self.type,
33+
'version': self.version,
34+
'augmentation': self.augmentation,
35+
'created': self.created,
36+
'preprocessing': self.preprocessing,
37+
'splits': self.splits}
2838
return json.dumps(json_value, indent=2)
29-
30-
31-

roboflow/core/workspace.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,52 @@
11
import requests
2+
import json
23
from roboflow.core.project import Project
34
from roboflow.config import *
45

56
class Workspace():
67
def __init__(self, info, api_key, default_workspace):
7-
8-
self.api_key = api_key
9-
self.name = default_workspace
10-
118
workspace_info = info['workspace']
9+
self.name = workspace_info['name']
10+
self.project_list = workspace_info['projects']
1211
self.members = workspace_info['members']
1312
self.url = workspace_info['url']
14-
self.project_list = []
1513

16-
for value in info['workspace']['projects']:
17-
self.project_list.append(value)
14+
self.__api_key = api_key
15+
1816

1917
def list_projects(self):
20-
print(self.projects)
18+
print(self.project_list)
2119

2220
def projects(self):
2321
projects_array = []
2422
for a_project in self.project_list:
25-
split = a_project['id'].rsplit("/")
26-
workspace, project_name = split[0], split[1]
27-
proj = Project(self.api_key, project_name, a_project['type'], workspace)
23+
proj = Project(self.__api_key, a_project)
2824
projects_array.append(proj)
2925

3026
return projects_array
3127

3228

3329
def project(self, project_name):
30+
project_name = project_name.replace(self.url + "/", "")
3431

3532
if "/" in project_name:
36-
raise RuntimeError("Do not re-specify the workspace {} in your project request".format(project_name.rsplit()[0]))
33+
raise RuntimeError("The {} project is not available in this ({}) workspace".format(project_name, self.url))
3734

38-
dataset_info = requests.get(API_URL + "/" + self.name + "/" + project_name + "?api_key=" + self.api_key)
35+
dataset_info = requests.get(API_URL + "/" + self.url + "/" + project_name + "?api_key=" + self.__api_key)
3936

4037
# Throw error if dataset isn't valid/user doesn't have permissions to access the dataset
4138
if dataset_info.status_code != 200:
4239
raise RuntimeError(dataset_info.text)
4340

4441
dataset_info = dataset_info.json()['project']
4542

46-
return Project(self.api_key, dataset_info['id'], dataset_info['type'], dataset_info['versions'])
43+
return Project(self.__api_key, dataset_info)
44+
45+
def __str__(self):
46+
json_value = {'name': self.name,
47+
'url': self.url,
48+
'members': self.members,
49+
'projects': self.projects
50+
}
51+
52+
return json.dumps(json_value, indent=2)

roboflow/models/classification.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import io
33
import os
44
import urllib
5+
import json
56

67
import requests
78
from PIL import Image
@@ -12,23 +13,24 @@
1213

1314

1415
class ClassificationModel:
15-
def __init__(self, api_key, dataset_slug=None, version=None, local=False):
16+
def __init__(self, api_key, id, name=None, version=None, local=False):
1617
"""
1718
1819
:param api_key:
19-
:param dataset_slug:
20+
:param name:
2021
:param version:
2122
"""
2223
# Instantiate different API URL parameters
23-
self.api_key = api_key
24-
self.dataset_slug = dataset_slug
24+
self.__api_key = api_key
25+
self.id=id
26+
self.name = name
2527
self.version = version
2628
if not local:
2729
self.base_url = "https://classify.roboflow.com/"
2830
else:
2931
self.base_url = "http://localhost:9001/"
3032

31-
if dataset_slug is not None and version is not None:
33+
if self.name is not None and version is not None:
3234
self.__generate_url()
3335

3436
def predict(self, image_path, hosted=False):
@@ -67,15 +69,15 @@ def predict(self, image_path, hosted=False):
6769
image_path=image_path,
6870
prediction_type=CLASSIFICATION_MODEL)
6971

70-
def load_model(self, dataset_slug, version):
72+
def load_model(self, name, version):
7173
"""
7274
73-
:param dataset_slug:
75+
:param name:
7476
:param version:
7577
:return:
7678
"""
7779
# Load model based on user defined characteristics
78-
self.dataset_slug = dataset_slug
80+
self.name = name
7981
self.version = version
8082
self.__generate_url()
8183

@@ -86,10 +88,12 @@ def __generate_url(self):
8688
"""
8789

8890
# Generates URL based on all parameters
89-
without_workspace = os.path.basename(self.dataset_slug)
91+
splitted = self.id.rsplit("/")
92+
without_workspace = splitted[1]
93+
9094
self.api_url = "".join([
9195
self.base_url + without_workspace + '/' + str(self.version),
92-
"?api_key=" + self.api_key,
96+
"?api_key=" + self.__api_key,
9397
"&name=YOUR_IMAGE.jpg"])
9498

9599
def __exception_check(self, image_path_check=None):
@@ -102,3 +106,10 @@ def __exception_check(self, image_path_check=None):
102106
if image_path_check is not None:
103107
if not os.path.exists(image_path_check) and not check_image_url(image_path_check):
104108
raise Exception("Image does not exist at " + image_path_check + "!")
109+
110+
def __str__(self):
111+
json_value = {'name': self.name,
112+
'version': self.version,
113+
'base_url': self.base_url}
114+
115+
return json.dumps(json_value, indent=2)

0 commit comments

Comments
 (0)