12
12
from dotenv import load_dotenv
13
13
from tqdm import tqdm
14
14
15
+ from roboflow .adapters import rfapi
15
16
from roboflow .config import (
16
17
API_URL ,
17
18
APP_URL ,
@@ -92,11 +93,11 @@ def __init__(
92
93
93
94
version_without_workspace = os .path .basename (str (version ))
94
95
95
- response = requests . get ( f" { API_URL } / { workspace } / { project } / { self . version } ?api_key= { self . __api_key } " )
96
- if response . ok :
97
- version_info = response . json ()[ "version" ]
96
+ try :
97
+ version_response = rfapi . get_version ( self . __api_key , workspace , project , self . version )
98
+ version_info = version_response . get ( "version" , {})
98
99
has_model = bool (version_info .get ("train" , {}).get ("model" ))
99
- else :
100
+ except rfapi . RoboflowError :
100
101
has_model = False
101
102
102
103
if not has_model :
@@ -152,16 +153,17 @@ def __init__(
152
153
153
154
def __check_if_generating (self ):
154
155
# check Roboflow API to see if this version is still generating
155
-
156
- url = f"{ API_URL } /{ self .workspace } /{ self .project } /{ self .version } ?nocache=true"
157
- response = requests .get (url , params = {"api_key" : self .__api_key })
158
- response .raise_for_status ()
159
- if response .json ()["version" ]["progress" ] is None :
160
- progress = 0.0
161
- else :
162
- progress = float (response .json ()["version" ]["progress" ])
163
-
164
- return response .json ()["version" ]["generating" ], progress
156
+ versiondict = rfapi .get_version (
157
+ api_key = self .__api_key ,
158
+ workspace_url = self .workspace ,
159
+ project_url = self .project ,
160
+ version = self .version ,
161
+ nocache = True ,
162
+ )
163
+ version_obj = versiondict .get ("version" , {})
164
+ progress = 0.0 if version_obj .get ("progress" ) is None else float (version_obj .get ("progress" ))
165
+ generating = bool (version_obj .get ("generating" ) or version_obj .get ("images" , 0 ) == 0 )
166
+ return generating , progress
165
167
166
168
def __wait_if_generating (self , recurse = False ):
167
169
# checks if a given version is still in the progress of generating
@@ -219,15 +221,22 @@ def download(self, model_format=None, location=None, overwrite: bool = False):
219
221
if self .__api_key == "coco-128-sample" :
220
222
link = "https://app.roboflow.com/ds/n9QwXwUK42?key=NnVCe2yMxP"
221
223
else :
222
- url = self .__get_download_url (model_format )
223
- response = requests .get (url , params = {"api_key" : self .__api_key })
224
- if response .status_code == 200 :
225
- link = response .json ()["export" ]["link" ]
226
- else :
227
- try :
228
- raise RuntimeError (response .json ())
229
- except json .JSONDecodeError :
230
- response .raise_for_status ()
224
+ workspace , project , * _ = self .id .rsplit ("/" )
225
+ try :
226
+ export_info = rfapi .get_version_export (
227
+ api_key = self .__api_key ,
228
+ workspace_url = workspace ,
229
+ project_url = project ,
230
+ version = self .version ,
231
+ format = model_format ,
232
+ )
233
+ except rfapi .RoboflowError as e :
234
+ raise RuntimeError (str (e ))
235
+
236
+ if "ready" in export_info and export_info .get ("ready" ) is False :
237
+ raise RuntimeError (export_info )
238
+
239
+ link = export_info ["export" ]["link" ]
231
240
232
241
self .__download_zip (link , location , model_format )
233
242
self .__extract_zip (location , model_format )
@@ -256,39 +265,36 @@ def export(self, model_format=None):
256
265
257
266
self .__wait_if_generating ()
258
267
259
- url = self .__get_download_url ( model_format )
260
- response = requests . get ( url , params = { "api_key" : self . __api_key })
261
- if not response . ok :
262
- try :
263
- raise RuntimeError ( response . json ())
264
- except json . JSONDecodeError :
265
- response . raise_for_status ()
266
-
267
- # the rest api returns 202 if the export is still in progress
268
- if response . status_code == 202 :
269
- status_code_check = 202
270
- while status_code_check == 202 :
271
- time . sleep ( 1 )
272
- response = requests . get ( url , params = { "api_key" : self . __api_key } )
273
- status_code_check = response . status_code
274
- if status_code_check == 202 :
275
- progress = response . json ()[ "progress" ]
276
- progress_message = (
277
- "Exporting format " + model_format + " in progress : " + str ( round ( progress * 100 , 2 )) + "%"
278
- )
279
- sys . stdout . write ( " \r " + progress_message )
280
- sys . stdout . flush ()
281
-
282
- if response . status_code == 200 :
268
+ workspace , project , * _ = self .id . rsplit ( "/" )
269
+ export_info = rfapi . get_version_export (
270
+ api_key = self . __api_key ,
271
+ workspace_url = workspace ,
272
+ project_url = project ,
273
+ version = self . version ,
274
+ format = model_format ,
275
+ )
276
+ while "ready" in export_info and export_info . get ( "ready" ) is False :
277
+ progress = export_info . get ( "progress" , 0.0 )
278
+ progress_message = (
279
+ "Exporting format " + model_format + " in progress : " + str ( round ( progress * 100 , 2 )) + "%"
280
+ )
281
+ sys . stdout . write ( " \r " + progress_message )
282
+ sys . stdout . flush ()
283
+ time . sleep ( 1 )
284
+ export_info = rfapi . get_version_export (
285
+ api_key = self . __api_key ,
286
+ workspace_url = workspace ,
287
+ project_url = project ,
288
+ version = self . version ,
289
+ format = model_format ,
290
+ )
291
+ if "export" in export_info :
283
292
sys .stdout .write ("\n " )
284
293
print ("\r " + "Version export complete for " + model_format + " format" )
285
294
sys .stdout .flush ()
286
295
return True
287
296
else :
288
- try :
289
- raise RuntimeError (response .json ())
290
- except json .JSONDecodeError :
291
- response .raise_for_status ()
297
+ raise RuntimeError (f"Unexpected export { export_info } " )
292
298
293
299
def train (self , speed = None , model_type = None , checkpoint = None , plot_in_notebook = False ) -> InferenceModel :
294
300
"""
@@ -326,28 +332,22 @@ def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=F
326
332
self .export (train_model_format )
327
333
328
334
workspace , project , * _ = self .id .rsplit ("/" )
329
- url = f"{ API_URL } /{ workspace } /{ project } /{ self .version } /train"
330
335
331
- data = {}
332
-
333
- if speed :
334
- data ["speed" ] = speed
335
-
336
- if checkpoint :
337
- data ["checkpoint" ] = checkpoint
338
-
339
- if model_type :
340
- # API expects camelCase key
341
- data ["modelType" ] = model_type
336
+ payload_speed = speed if speed else None
337
+ payload_checkpoint = checkpoint if checkpoint else None
338
+ payload_model_type = model_type if model_type else None
342
339
343
340
write_line ("Reaching out to Roboflow to start training..." )
344
341
345
- response = requests .post (url , json = data , params = {"api_key" : self .__api_key })
346
- if not response .ok :
347
- try :
348
- raise RuntimeError (response .json ())
349
- except json .JSONDecodeError :
350
- response .raise_for_status ()
342
+ rfapi .start_version_training (
343
+ api_key = self .__api_key ,
344
+ workspace_url = workspace ,
345
+ project_url = project ,
346
+ version = self .version ,
347
+ speed = payload_speed ,
348
+ checkpoint = payload_checkpoint ,
349
+ model_type = payload_model_type ,
350
+ )
351
351
352
352
status = "training"
353
353
@@ -374,10 +374,14 @@ def live_plot(epochs, mAP, loss, title=""):
374
374
num_machine_spin_dots = []
375
375
376
376
while status == "training" or status == "running" :
377
- url = f"{ API_URL } /{ self .workspace } /{ self .project } /{ self .version } ?nocache=true"
378
- response = requests .get (url , params = {"api_key" : self .__api_key })
379
- response .raise_for_status ()
380
- version = response .json ()["version" ]
377
+ version_response = rfapi .get_version (
378
+ api_key = self .__api_key ,
379
+ workspace_url = self .workspace ,
380
+ project_url = self .project ,
381
+ version = self .version ,
382
+ nocache = True ,
383
+ )
384
+ version = version_response .get ("version" , {})
381
385
if "models" in version .keys ():
382
386
models = version ["models" ]
383
387
else :
0 commit comments