@@ -383,6 +383,15 @@ class CreatePredictionParams(TypedDict):
383383 stream : NotRequired [bool ]
384384 """Enable streaming of prediction output."""
385385
386+ wait : NotRequired [Union [int , bool ]]
387+ """
388+ Wait until the prediction is completed before returning.
389+
390+ If `True`, wait a predetermined number of seconds until the prediction
391+ is completed before returning.
392+ If an `int`, wait for the specified number of seconds.
393+ """
394+
386395 file_encoding_strategy : NotRequired [FileEncodingStrategy ]
387396 """The strategy to use for encoding files in the prediction input."""
388397
@@ -463,6 +472,7 @@ def create( # type: ignore
463472 client = self ._client ,
464473 file_encoding_strategy = file_encoding_strategy ,
465474 )
475+ headers = _create_prediction_headers (wait = params .pop ("wait" , None ))
466476 body = _create_prediction_body (
467477 version ,
468478 input ,
@@ -472,6 +482,7 @@ def create( # type: ignore
472482 resp = self ._client ._request (
473483 "POST" ,
474484 "/v1/predictions" ,
485+ headers = headers ,
475486 json = body ,
476487 )
477488
@@ -554,6 +565,7 @@ async def async_create( # type: ignore
554565 client = self ._client ,
555566 file_encoding_strategy = file_encoding_strategy ,
556567 )
568+ headers = _create_prediction_headers (wait = params .pop ("wait" , None ))
557569 body = _create_prediction_body (
558570 version ,
559571 input ,
@@ -563,6 +575,7 @@ async def async_create( # type: ignore
563575 resp = await self ._client ._async_request (
564576 "POST" ,
565577 "/v1/predictions" ,
578+ headers = headers ,
566579 json = body ,
567580 )
568581
@@ -603,6 +616,20 @@ async def async_cancel(self, id: str) -> Prediction:
603616 return _json_to_prediction (self ._client , resp .json ())
604617
605618
619+ def _create_prediction_headers (
620+ * ,
621+ wait : Optional [Union [int , bool ]] = None ,
622+ ) -> Dict [str , Any ]:
623+ headers = {}
624+
625+ if wait :
626+ if isinstance (wait , bool ):
627+ headers ["Prefer" ] = "wait"
628+ elif isinstance (wait , int ):
629+ headers ["Prefer" ] = f"wait={ wait } "
630+ return headers
631+
632+
606633def _create_prediction_body ( # pylint: disable=too-many-arguments
607634 version : Optional [Union [Version , str ]],
608635 input : Optional [Dict [str , Any ]],
0 commit comments