1+ import asyncio
12import re
23import time
34from dataclasses import dataclass
@@ -114,6 +115,7 @@ def progress(self) -> Optional[Progress]:
114115 """
115116 The progress of the prediction, if available.
116117 """
118+
117119 if self .logs is None or self .logs == "" :
118120 return None
119121
@@ -123,10 +125,20 @@ def wait(self) -> None:
123125 """
124126 Wait for prediction to finish.
125127 """
128+
126129 while self .status not in ["succeeded" , "failed" , "canceled" ]:
127130 time .sleep (self ._client .poll_interval )
128131 self .reload ()
129132
133+ async def async_wait (self ) -> None :
134+ """
135+ Wait for prediction to finish asynchronously.
136+ """
137+
138+ while self .status not in ["succeeded" , "failed" , "canceled" ]:
139+ await asyncio .sleep (self ._client .poll_interval )
140+ await self .async_reload ()
141+
130142 def stream (self ) -> Optional [Iterator ["ServerSentEvent" ]]:
131143 """
132144 Stream the prediction output.
@@ -164,6 +176,15 @@ def reload(self) -> None:
164176 for name , value in updated .dict ().items ():
165177 setattr (self , name , value )
166178
179+ async def async_reload (self ) -> None :
180+ """
181+ Load this prediction from the server asynchronously.
182+ """
183+
184+ updated = await self ._client .predictions .async_get (self .id )
185+ for name , value in updated .dict ().items ():
186+ setattr (self , name , value )
187+
167188 def output_iterator (self ) -> Iterator [Any ]:
168189 """
169190 Return an iterator of the prediction output.
0 commit comments