1515
1616from replicate import identifier
1717from replicate .exceptions import ReplicateError
18+ from replicate .helpers import transform_output
1819
1920try :
2021 from pydantic import v1 as pydantic # type: ignore
@@ -62,10 +63,19 @@ class EventSource:
6263 A server-sent event source.
6364 """
6465
66+ client : "Client"
6567 response : "httpx.Response"
66-
67- def __init__ (self , response : "httpx.Response" ) -> None :
68+ use_file_output : bool
69+
70+ def __init__ (
71+ self ,
72+ client : "Client" ,
73+ response : "httpx.Response" ,
74+ use_file_output : Optional [bool ] = None ,
75+ ) -> None :
76+ self .client = client
6877 self .response = response
78+ self .use_file_output = use_file_output or False
6979 content_type , _ , _ = response .headers ["content-type" ].partition (";" )
7080 if content_type != "text/event-stream" :
7181 raise ValueError (
@@ -147,6 +157,12 @@ def __iter__(self) -> Iterator[ServerSentEvent]:
147157 if sse .event == ServerSentEvent .EventType .ERROR :
148158 raise RuntimeError (sse .data )
149159
160+ if (
161+ self .use_file_output
162+ and sse .event == ServerSentEvent .EventType .OUTPUT
163+ ):
164+ sse .data = transform_output (sse .data , client = self .client )
165+
150166 yield sse
151167
152168 if sse .event == ServerSentEvent .EventType .DONE :
@@ -161,6 +177,12 @@ async def __aiter__(self) -> AsyncIterator[ServerSentEvent]:
161177 if sse .event == ServerSentEvent .EventType .ERROR :
162178 raise RuntimeError (sse .data )
163179
180+ if (
181+ self .use_file_output
182+ and sse .event == ServerSentEvent .EventType .OUTPUT
183+ ):
184+ sse .data = transform_output (sse .data , client = self .client )
185+
164186 yield sse
165187
166188 if sse .event == ServerSentEvent .EventType .DONE :
@@ -171,6 +193,7 @@ def stream(
171193 client : "Client" ,
172194 ref : Union ["Model" , "Version" , "ModelVersionIdentifier" , str ],
173195 input : Optional [Dict [str , Any ]] = None ,
196+ use_file_output : Optional [bool ] = None ,
174197 ** params : Unpack ["Predictions.CreatePredictionParams" ],
175198) -> Iterator [ServerSentEvent ]:
176199 """
@@ -204,13 +227,14 @@ def stream(
204227 headers ["Cache-Control" ] = "no-store"
205228
206229 with client ._client .stream ("GET" , url , headers = headers ) as response :
207- yield from EventSource (response )
230+ yield from EventSource (client , response , use_file_output = use_file_output )
208231
209232
210233async def async_stream (
211234 client : "Client" ,
212235 ref : Union ["Model" , "Version" , "ModelVersionIdentifier" , str ],
213236 input : Optional [Dict [str , Any ]] = None ,
237+ use_file_output : Optional [bool ] = None ,
214238 ** params : Unpack ["Predictions.CreatePredictionParams" ],
215239) -> AsyncIterator [ServerSentEvent ]:
216240 """
@@ -244,7 +268,9 @@ async def async_stream(
244268 headers ["Cache-Control" ] = "no-store"
245269
246270 async with client ._async_client .stream ("GET" , url , headers = headers ) as response :
247- async for event in EventSource (response ):
271+ async for event in EventSource (
272+ client , response , use_file_output = use_file_output
273+ ):
248274 yield event
249275
250276
0 commit comments