11import io
2+ import json
3+ import os
4+ import time
25import urllib
6+ from typing import List
7+ from urllib .parse import urljoin
38
49import requests
510from PIL import Image
611from requests_toolbelt .multipart .encoder import MultipartEncoder
712
13+ from roboflow .config import API_URL
814from roboflow .util .image_utils import validate_image_path
915from roboflow .util .prediction import PredictionGroup
1016
# Roboflow-hosted prediction types that support batch video inference.
SUPPORTED_ROBOFLOW_MODELS = ["batch-video"]

# Auxiliary models that can be run alongside the project's own model during
# video inference, keyed by the short name callers pass in `additional_models`.
# Each entry is the exact payload fragment the video inference API expects.
SUPPORTED_ADDITIONAL_MODELS = {
    "clip": {
        "model_id": "clip",
        "model_version": "1",
        "inference_type": "clip-embed-image",
    },
    "gaze": {
        "model_id": "gaze",
        "model_version": "1",
        "inference_type": "gaze-detection",
    },
}
31+
1132
1233class InferenceModel :
1334 def __init__ (
@@ -25,13 +46,15 @@ def __init__(
2546 api_key (str): private roboflow api key
2647 version_id (str): the ID of the dataset version to use for inference
2748 """
49+
2850 self .__api_key = api_key
2951 self .id = version_id
3052
31- version_info = self .id .rsplit ("/" )
32- self .dataset_id = version_info [1 ]
33- self .version = version_info [2 ]
34- self .colors = {} if colors is None else colors
53+ if version_id != "BASE_MODEL" :
54+ version_info = self .id .rsplit ("/" )
55+ self .dataset_id = version_info [1 ]
56+ self .version = version_info [2 ]
57+ self .colors = {} if colors is None else colors
3558
3659 def __get_image_params (self , image_path ):
3760 """
@@ -111,3 +134,238 @@ def predict(self, image_path, prediction_type=None, **kwargs):
111134 image_dims = image_dims ,
112135 colors = self .colors ,
113136 )
137+
138+ def predict_video (
139+ self ,
140+ video_path : str ,
141+ fps : int = 5 ,
142+ additional_models : list = [],
143+ prediction_type : str = "batch-video" ,
144+ ) -> List [str ]:
145+ """
146+ Infers detections based on image from specified model and image path.
147+
148+ Args:
149+ video_path (str): path to the video you'd like to perform prediction on
150+ prediction_type (str): type of the model to run
151+ fps (int): frames per second to run inference
152+
153+ Returns:
154+ A list of the signed url and job id
155+
156+ Example:
157+ >>> import roboflow
158+
159+ >>> rf = roboflow.Roboflow(api_key="")
160+
161+ >>> project = rf.workspace().project("PROJECT_ID")
162+
163+ >>> model = project.version("1").model
164+
165+ >>> job_id, signed_url, signed_url_expires = model.predict_video("video.mp4", fps=5, inference_type="object-detection")
166+ """
167+
168+ signed_url_expires = None
169+
170+ url = urljoin (API_URL , "/video_upload_signed_url?api_key=" + self .__api_key )
171+
172+ if fps > 5 :
173+ raise Exception ("FPS must be less than or equal to 5." )
174+
175+ for model in additional_models :
176+ if model not in SUPPORTED_ADDITIONAL_MODELS :
177+ raise Exception (f"Model { model } is not supported for video inference." )
178+
179+ if prediction_type not in SUPPORTED_ROBOFLOW_MODELS :
180+ raise Exception (f"{ prediction_type } is not supported for video inference." )
181+
182+ model_class = self .__class__ .__name__
183+
184+ if model_class == "ObjectDetectionModel" :
185+ self .type = "object-detection"
186+ elif model_class == "ClassificationModel" :
187+ self .type = "classification"
188+ elif model_class == "InstanceSegmentationModel" :
189+ self .type = "instance-segmentation"
190+ elif model_class == "GazeModel" :
191+ self .type = "gaze-detection"
192+ elif model_class == "CLIPModel" :
193+ self .type = "clip-embed-image"
194+ else :
195+ raise Exception ("Model type not supported for video inference." )
196+
197+ payload = json .dumps (
198+ {
199+ "file_name" : os .path .basename (video_path ),
200+ }
201+ )
202+
203+ if not video_path .startswith (("http://" , "https://" )):
204+ headers = {"Content-Type" : "application/json" }
205+
206+ try :
207+ response = requests .request ("POST" , url , headers = headers , data = payload )
208+ except Exception as e :
209+ raise Exception (f"Error uploading video: { e } " )
210+
211+ if not response .ok :
212+ raise Exception (f"Error uploading video: { response .text } " )
213+
214+ signed_url = response .json ()["signed_url" ]
215+
216+ signed_url_expires = (
217+ signed_url .split ("&X-Goog-Expires" )[1 ].split ("&" )[0 ].strip ("=" )
218+ )
219+
220+ # make a POST request to the signed URL
221+ headers = {"Content-Type" : "application/octet-stream" }
222+
223+ try :
224+ with open (video_path , "rb" ) as f :
225+ video_data = f .read ()
226+ except Exception as e :
227+ raise Exception (f"Error reading video: { e } " )
228+
229+ try :
230+ result = requests .put (signed_url , data = video_data , headers = headers )
231+ except Exception as e :
232+ raise Exception (f"There was an error uploading the video: { e } " )
233+
234+ if not result .ok :
235+ raise Exception (
236+ f"There was an error uploading the video: { result .text } "
237+ )
238+ else :
239+ signed_url = video_path
240+
241+ url = urljoin (API_URL , "/videoinfer/?api_key=" + self .__api_key )
242+
243+ if model_class in ("CLIPModel" , "GazeModel" ):
244+ if model_class == "CLIPModel" :
245+ model = "clip"
246+ else :
247+ model = "gaze"
248+
249+ models = [
250+ {
251+ "model_id" : SUPPORTED_ADDITIONAL_MODELS [model ]["model_id" ],
252+ "model_version" : SUPPORTED_ADDITIONAL_MODELS [model ][
253+ "model_version"
254+ ],
255+ "inference_type" : SUPPORTED_ADDITIONAL_MODELS [model ][
256+ "inference_type"
257+ ],
258+ }
259+ ]
260+
261+ for model in additional_models :
262+ models .append (SUPPORTED_ADDITIONAL_MODELS [model ])
263+
264+ payload = json .dumps (
265+ {"input_url" : signed_url , "infer_fps" : 5 , "models" : models }
266+ )
267+
268+ headers = {"Content-Type" : "application/json" }
269+
270+ try :
271+ response = requests .request ("POST" , url , headers = headers , data = payload )
272+ except Exception as e :
273+ raise Exception (f"Error starting video inference: { e } " )
274+
275+ if not response .ok :
276+ raise Exception (f"Error starting video inference: { response .text } " )
277+
278+ job_id = response .json ()["job_id" ]
279+
280+ self .job_id = job_id
281+
282+ return job_id , signed_url , signed_url_expires
283+
284+ def poll_for_video_results (self , job_id : str = None ) -> dict :
285+ """
286+ Polls the Roboflow API to check if video inference is complete.
287+
288+ Returns:
289+ Inference results as a dict
290+
291+ Example:
292+ >>> import roboflow
293+
294+ >>> rf = roboflow.Roboflow(api_key="")
295+
296+ >>> project = rf.workspace().project("PROJECT_ID")
297+
298+ >>> model = project.version("1").model
299+
300+ >>> prediction = model.predict("video.mp4")
301+
302+ >>> results = model.poll_for_video_results()
303+ """
304+
305+ if job_id is None :
306+ job_id = self .job_id
307+
308+ url = urljoin (
309+ API_URL , "/videoinfer/?api_key=" + self .__api_key + "&job_id=" + self .job_id
310+ )
311+
312+ try :
313+ response = requests .get (url , headers = {"Content-Type" : "application/json" })
314+ except Exception as e :
315+ raise Exception (f"Error getting video inference results: { e } " )
316+
317+ if not response .ok :
318+ raise Exception (f"Error getting video inference results: { response .text } " )
319+
320+ data = response .json ()
321+
322+ if data .get ("status" ) != 0 :
323+ return {}
324+
325+ output_signed_url = data ["output_signed_url" ]
326+
327+ inference_data = requests .get (
328+ output_signed_url , headers = {"Content-Type" : "application/json" }
329+ )
330+
331+ # frame_offset and model name are top-level keys
332+ return inference_data .json ()
333+
334+ def poll_until_video_results (self , job_id ) -> dict :
335+ """
336+ Polls the Roboflow API to check if video inference is complete.
337+
338+ When inference is complete, the results are returned.
339+
340+ Returns:
341+ Inference results as a dict
342+
343+ Example:
344+ >>> import roboflow
345+
346+ >>> rf = roboflow.Roboflow(api_key="")
347+
348+ >>> project = rf.workspace().project("PROJECT_ID")
349+
350+ >>> model = project.version("1").model
351+
352+ >>> prediction = model.predict("video.mp4")
353+
354+ >>> results = model.poll_until_results()
355+ """
356+ if job_id is None :
357+ job_id = self .job_id
358+
359+ attempts = 0
360+
361+ while True :
362+ print (f"({ attempts * 60 } s): Checking for inference results" )
363+
364+ response = self .poll_for_video_results ()
365+
366+ time .sleep (60 )
367+
368+ attempts += 1
369+
370+ if response != {}:
371+ return response