22import requests
33
44from .tasks import Task
5+ from .batches import Batch
56
67TASK_TYPES = [
78 'annotation' ,
1011 'comparison' ,
1112 'cuboidannotation' ,
1213 'datacollection' ,
13- 'imageannotation' ,
14+ 'imageannotation' ,
1415 'lineannotation' ,
1516 'namedentityrecognition' ,
1617 'pointannotation' ,
1718 'polygonannotation' ,
1819 'segmentannotation' ,
1920 'transcription' ,
20- 'videoannotation' ,
21- 'videoboxannotation' ,
21+ 'videoannotation' ,
22+ 'videoboxannotation' ,
2223 'videocuboidannotation'
2324]
2425SCALE_ENDPOINT = 'https://api.scale.com/v1/'
@@ -35,27 +36,37 @@ class ScaleInvalidRequest(ScaleException, ValueError):
3536 pass
3637
3738
38- class Tasklist (list ):
39+ class Paginator (list ):
3940 def __init__ (self , docs , total , limit , offset , has_more , next_token = None ):
40- super (Tasklist , self ).__init__ (docs )
41+ super (Paginator , self ).__init__ (docs )
4142 self .docs = docs
4243 self .total = total
4344 self .limit = limit
4445 self .offset = offset
4546 self .has_more = has_more
4647 self .next_token = next_token
4748
49+
50+ class Tasklist (Paginator ):
51+ pass
52+
53+
54+ class Batchlist (Paginator ):
55+ pass
56+
57+
4858class ScaleClient (object ):
4959 def __init__ (self , api_key ):
5060 self .api_key = api_key
5161
52- def _getrequest (self , endpoint , params = {} ):
62+ def _getrequest (self , endpoint , params = None ):
5363 """Makes a get request to an endpoint.
5464
5565 If an error occurs, assumes that endpoint returns JSON as:
5666 { 'status_code': XXX,
5767 'error': 'I failed' }
5868 """
69+ params = params or {}
5970 r = requests .get (SCALE_ENDPOINT + endpoint ,
6071 headers = {"Content-Type" : "application/json" },
6172 auth = (self .api_key , '' ), params = params )
@@ -114,7 +125,7 @@ def cancel_task(self, task_id):
114125 def tasks (self , ** kwargs ):
115126 """Returns a list of your tasks.
116127 Returns up to 100 at a time, to get more, use the next_token param passed back.
117-
128+
118129 Note that offset is deprecated.
119130
120131 start/end_time are ISO8601 dates, the time range of tasks to fetch.
@@ -125,7 +136,7 @@ def tasks(self, **kwargs):
125136 offset (deprecated) is the number of results to skip (for showing more pages).
126137 """
127138 allowed_kwargs = {'start_time' , 'end_time' , 'status' , 'type' , 'project' ,
128- 'batch' , 'limit' , 'offset' , 'completed_before' , 'completed_after' ,
139+ 'batch' , 'limit' , 'offset' , 'completed_before' , 'completed_after' ,
129140 'next_token' }
130141 for key in kwargs :
131142 if key not in allowed_kwargs :
@@ -140,6 +151,29 @@ def create_task(self, task_type, **kwargs):
140151 taskdata = self ._postrequest (endpoint , payload = kwargs )
141152 return Task (taskdata , self )
142153
154+ def create_batch (self , project , batch_name , callback ):
155+ payload = dict (project = project , name = batch_name , callback = callback )
156+ batchdata = self ._postrequest ('batches' , payload )
157+ return Batch (batchdata , self )
158+
159+ def get_batch (self , batch_name : str ):
160+ batchdata = self ._getrequest ('batches/%s' % batch_name )
161+ return Batch (batchdata , self )
162+
163+ def list_batches (self , ** kwargs ):
164+ allowed_kwargs = { 'start_time' , 'end_time' , 'status' , 'project' ,
165+ 'batch' , 'limit' , 'offset' , }
166+ for key in kwargs :
167+ if key not in allowed_kwargs :
168+ raise ScaleInvalidRequest ('Illegal parameter %s for ScaleClient.tasks()'
169+ % key , None )
170+ response = self ._getrequest ('tasks' , params = kwargs )
171+ docs = [Batch (doc , self ) for doc in response ['docs' ]]
172+ return Batchlist (
173+ docs , response ['total' ], response ['limit' ], response ['offset' ],
174+ response ['has_more' ], response .get ('next_token' ),
175+ )
176+
143177
144178def _AddTaskTypeCreator (task_type ):
145179 def create_task_wrapper (self , ** kwargs ):
0 commit comments