@@ -185,42 +185,84 @@ describe("Replicate client", () => {
185185 } ) ;
186186
187187 describe ( "predictions.create" , ( ) => {
188- test ( "Calls the correct API route with the correct payload" , async ( ) => {
189- nock ( BASE_URL )
190- . post ( "/predictions" )
191- . reply ( 200 , {
192- id : "ufawqhfynnddngldkgtslldrkq" ,
193- model : "replicate/hello-world" ,
194- version :
195- "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
196- urls : {
197- get : "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq" ,
198- cancel :
199- "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel" ,
200- } ,
201- created_at : "2022-04-26T22:13:06.224088Z" ,
202- started_at : null ,
203- completed_at : null ,
204- status : "starting" ,
205- input : {
206- text : "Alice" ,
188+ const predictionTestCases = [
189+ {
190+ description : "String input" ,
191+ input : {
192+ text : "Alice" ,
193+ } ,
194+ } ,
195+ {
196+ description : "Number input" ,
197+ input : {
198+ text : 123 ,
199+ } ,
200+ } ,
201+ {
202+ description : "Boolean input" ,
203+ input : {
204+ text : true ,
205+ } ,
206+ } ,
207+ {
208+ description : "Array input" ,
209+ input : {
210+ text : [ "Alice" , "Bob" , "Charlie" ] ,
211+ } ,
212+ } ,
213+ {
214+ description : "Object input" ,
215+ input : {
216+ text : {
217+ name : "Alice" ,
207218 } ,
208- output : null ,
209- error : null ,
210- logs : null ,
211- metrics : { } ,
212- } ) ;
213- const prediction = await client . predictions . create ( {
219+ } ,
220+ } ,
221+ ] . map ( ( testCase ) => ( {
222+ ...testCase ,
223+ expectedResponse : {
224+ id : "ufawqhfynnddngldkgtslldrkq" ,
225+ model : "replicate/hello-world" ,
214226 version :
215227 "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
216- input : {
217- text : "Alice" ,
228+ urls : {
229+ get : "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq" ,
230+ cancel :
231+ "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel" ,
218232 } ,
219- webhook : "http://test.host/webhook" ,
220- webhook_events_filter : [ "output" , "completed" ] ,
221- } ) ;
222- expect ( prediction . id ) . toBe ( "ufawqhfynnddngldkgtslldrkq" ) ;
223- } ) ;
233+ input : testCase . input ,
234+ created_at : "2022-04-26T22:13:06.224088Z" ,
235+ started_at : null ,
236+ completed_at : null ,
237+ status : "starting" ,
238+ } ,
239+ } ) ) ;
240+
241+ test . each ( predictionTestCases ) (
242+ "$description" ,
243+ async ( { input, expectedResponse } ) => {
244+ nock ( BASE_URL )
245+ . post ( "/predictions" , {
246+ version :
247+ "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
248+ input : input as Record < string , any > ,
249+ webhook : "http://test.host/webhook" ,
250+ webhook_events_filter : [ "output" , "completed" ] ,
251+ } )
252+ . reply ( 200 , expectedResponse ) ;
253+
254+ const response = await client . predictions . create ( {
255+ version :
256+ "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
257+ input : input as Record < string , any > ,
258+ webhook : "http://test.host/webhook" ,
259+ webhook_events_filter : [ "output" , "completed" ] ,
260+ } ) ;
261+
262+ expect ( response . input ) . toEqual ( input ) ;
263+ expect ( response . status ) . toBe ( expectedResponse . status ) ;
264+ }
265+ ) ;
224266
225267 const fileTestCases = [
226268 // Skip test case if File type is not available
0 commit comments