File tree Expand file tree Collapse file tree 2 files changed +53
-0
lines changed Expand file tree Collapse file tree 2 files changed +53
-0
lines changed Original file line number Diff line number Diff line change @@ -129,6 +129,8 @@ def create( # type: ignore
129129 webhook : Optional [str ] = None ,
130130 webhook_completed : Optional [str ] = None ,
131131 webhook_events_filter : Optional [List [str ]] = None ,
132+ * ,
133+ stream : Optional [bool ] = None ,
132134 ** kwargs ,
133135 ) -> Prediction :
134136 """
@@ -157,6 +159,8 @@ def create( # type: ignore
157159 body ["webhook_completed" ] = webhook_completed
158160 if webhook_events_filter is not None :
159161 body ["webhook_events_filter" ] = webhook_events_filter
162+ if stream is True :
163+ body ["stream" ] = "true"
160164
161165 resp = self ._client ._request (
162166 "POST" ,
Original file line number Diff line number Diff line change @@ -94,6 +94,55 @@ def test_cancel():
9494 assert rsp .call_count == 1
9595
9696
97+ @responses .activate
98+ def test_stream ():
99+ client = create_client ()
100+ version = create_version (client )
101+
102+ rsp = responses .post (
103+ "https://api.replicate.com/v1/predictions" ,
104+ match = [
105+ matchers .json_params_matcher (
106+ {
107+ "version" : "v1" ,
108+ "input" : {"text" : "world" },
109+ "stream" : "true" ,
110+ }
111+ ),
112+ ],
113+ json = {
114+ "id" : "p1" ,
115+ "version" : "v1" ,
116+ "urls" : {
117+ "get" : "https://api.replicate.com/v1/predictions/p1" ,
118+ "cancel" : "https://api.replicate.com/v1/predictions/p1/cancel" ,
119+ "stream" : "https://streaming.api.replicate.com/v1/predictions/p1" ,
120+ },
121+ "created_at" : "2022-04-26T20:00:40.658234Z" ,
122+ "completed_at" : "2022-04-26T20:02:27.648305Z" ,
123+ "source" : "api" ,
124+ "status" : "processing" ,
125+ "input" : {"text" : "world" },
126+ "output" : None ,
127+ "error" : None ,
128+ "logs" : "" ,
129+ },
130+ )
131+
132+ prediction = client .predictions .create (
133+ version = version ,
134+ input = {"text" : "world" },
135+ stream = True ,
136+ )
137+
138+ assert rsp .call_count == 1
139+
140+ assert (
141+ prediction .urls ["stream" ]
142+ == "https://streaming.api.replicate.com/v1/predictions/p1"
143+ )
144+
145+
97146@responses .activate
98147def test_async_timings ():
99148 client = create_client ()
You can’t perform that action at this time.
0 commit comments