22import io
33import json
44import os
5- import pathlib
65import urllib
76import warnings
8-
97import cv2
108import requests
119from PIL import Image
12-
13- from roboflow .models .classification import ClassificationModel
14- from roboflow .models .object_detection import ObjectDetectionModel
1510from roboflow .config import *
11+ from roboflow .core .version import Version
1612
17-
13+ #version class that should return
1814class Project ():
19- def __init__ (self , api_key , dataset_slug , type , versions , access_token , publishable_key ):
15+ def __init__ (self , api_key , dataset_slug , type , workspace ):
2016 self .api_key = api_key
21- self .dataset_slug = dataset_slug
22- self .type = type
23- self .access_token = access_token
24- self .publishable_key = publishable_key
25- # Dictionary of versions + names
26- self .versions_and_names = versions
27- # List of all versions to choose from
28- self .versions = list (int (vers ) for vers in versions .keys ())
29-
30- def model (self , version , local = False ):
31- # Check if version number is an available version to choose from
32- if version not in self .versions :
33- raise RuntimeError (
34- str (version ) + " is an invalid version; please select a different version from " + str (self .versions ))
35-
36- # Check whether model exists before initializing model
37- model_info_response = requests .get (
38- API_URL + "/model/" + self .dataset_slug + "/" + str (version ) + "?access_token=" + self .access_token )
39- if model_info_response .status_code != 200 :
40- raise RuntimeError (model_info_response .text )
41-
42- model_info_response = model_info_response .json ()
43- # Return appropriate model if model does exist
44- if model_info_response ['exists' ]:
45- if self .type == "object-detection" :
46- return ObjectDetectionModel (self .api_key , self .dataset_slug , version , local = local )
47- elif self .type == "classification" :
48- return ClassificationModel (self .api_key , self .dataset_slug , version , local = local )
17+ self .name = dataset_slug
18+ self .category = type
19+ self .workspace = workspace
20+ self .all_versions = []
21+
22+ def get_version_information (self ):
23+
24+ slug_splitted = self .name .rsplit ("/" )
25+ p , w = slug_splitted [0 ], slug_splitted [1 ]
26+
27+ dataset_info = requests .get (API_URL + "/" + p + "/" + w + "?api_key=" + self .api_key )
28+
29+ # Throw error if dataset isn't valid/user doesn't have permissions to access the dataset
30+ if dataset_info .status_code != 200 :
31+ raise RuntimeError (dataset_info .text )
32+
33+ dataset_info = dataset_info .json ()['project' ]
34+ return dataset_info ['versions' ]
35+
36+ def list_versions (self ):
37+ version_info = self .get_version_information ()
38+ print (version_info )
39+
40+ def versions (self ):
41+ version_info = self .get_version_information ()
42+ version_array = []
43+ for a_version in version_info :
44+ version_object = Version ((self .category if 'model' in a_version else None ), self .api_key , self .name , a_version ['id' ], local = False )
45+ version_array .append (version_object )
46+
47+ return version_array
48+
49+ def version (self , version_number ):
50+
51+ version_info = self .get_version_information ()
52+
53+ for version_object in version_info :
54+
55+ current_version_num = os .path .basename (version_object ['id' ])
56+ if current_version_num == version_number :
57+ vers = Version (self .category , self .api_key , self .name , current_version_num , local = False )
58+ return vers
59+
60+ raise RuntimeError ("Version number {} is not found." .format (version_number ))
4961
5062 def __image_upload (self , image_path , hosted_image = False , split = "train" ):
63+
5164 # If image is not a hosted image
5265 if not hosted_image :
66+ project_name = os .path .basename (self .name )
67+ image_name = os .path .basename (image_path )
5368 # Construct URL for local image upload
5469 self .image_upload_url = "" .join ([
55- "https://api.roboflow.com/dataset/" , self . dataset_slug , "/upload" ,
70+ "https://api.roboflow.com/dataset/" , project_name , "/upload" ,
5671 "?api_key=" , self .api_key ,
57- "&name=" + os . path . basename ( image_path ) ,
72+ "&name=" + image_name ,
5873 "&split=" + split
5974 ])
75+
6076 # Convert to PIL Image
61- image = cv2 .cvtColor (image_path , cv2 .COLOR_BGR2RGB )
77+ img = cv2 .imread (image_path )
78+ image = cv2 .cvtColor (img , cv2 .COLOR_BGR2RGB )
79+
6280 pilImage = Image .fromarray (image )
6381 # Convert to JPEG Buffer
6482 buffered = io .BytesIO ()
@@ -67,14 +85,16 @@ def __image_upload(self, image_path, hosted_image=False, split="train"):
6785 img_str = base64 .b64encode (buffered .getvalue ())
6886 img_str = img_str .decode ("ascii" )
6987 # Post Base 64 Data to Upload API
88+
7089 response = requests .post (self .image_upload_url , data = img_str , headers = {
7190 "Content-Type" : "application/x-www-form-urlencoded"
7291 })
7392
93+
7494 else :
7595 # Hosted image upload url
7696 upload_url = "" .join ([
77- "https://api.roboflow.com/dataset/" + self .dataset_slug + "/upload" ,
97+ "https://api.roboflow.com/dataset/" + self .name + "/upload" ,
7898 "?api_key=" + self .api_key ,
7999 "&name=" + os .path .basename (image_path ),
80100 "&split=" + split ,
@@ -83,14 +103,15 @@ def __image_upload(self, image_path, hosted_image=False, split="train"):
83103 # Get response
84104 response = requests .post (upload_url )
85105 # Return response
106+
86107 return response
87108
88109 def __annotation_upload (self , annotation_path , image_id ):
89110 # Get annotation string
90111 annotation_string = open (annotation_path , "r" ).read ()
91112 # Set annotation upload url
92113 self .annotation_upload_url = "" .join ([
93- "https://api.roboflow.com/dataset/" , self .dataset_slug , "/annotate/" , image_id ,
114+ "https://api.roboflow.com/dataset/" , self .name , "/annotate/" , image_id ,
94115 "?api_key=" , self .api_key ,
95116 "&name=" + os .path .basename (annotation_path )
96117 ])
@@ -107,6 +128,7 @@ def upload(self, image_path=None, annotation_path=None, hosted_image=False, imag
107128 if image_path is not None :
108129 # Upload Image Response
109130 response = self .__image_upload (image_path , hosted_image = hosted_image , split = split )
131+
110132 # Get JSON response values
111133 try :
112134 success , image_id = response .json ()['success' ], response .json ()['id' ]
@@ -144,9 +166,9 @@ def upload(self, image_path=None, annotation_path=None, hosted_image=False, imag
144166 def __str__ (self ):
145167 # String representation of project
146168 json_str = {
147- "dataset_slug" : self .dataset_slug ,
148- "dataset_type " : self .type ,
149- "dataset_versions " : self .versions_and_names
169+ "dataset_slug" : self .name ,
170+ "task_type " : self .category ,
171+ "workspace " : self .workspace
150172 }
151173
152174 return json .dumps (json_str , indent = 2 )
0 commit comments