@@ -9,7 +9,7 @@ describe('Replicate client', () => {
99
1010 beforeEach ( ( ) => {
1111 client = new Replicate ( { auth : 'test-token' } ) ;
12- client [ 'instance' ] = jest . fn < typeof axios > ( ) ;
12+ client [ 'instance' ] = jest . fn < typeof axios > ( ) ;
1313 } ) ;
1414
1515 describe ( 'constructor' , ( ) => {
@@ -36,7 +36,7 @@ describe('Replicate client', () => {
3636
3737 describe ( 'collections.get' , ( ) => {
3838 test ( 'Calls the correct API route' , async ( ) => {
39- client [ 'instance' ] . mockResolvedValueOnce ( {
39+ client [ 'instance' ] . mockResolvedValueOnce ( {
4040 data : {
4141 name : 'Super resolution' ,
4242 slug : 'super-resolution' ,
@@ -46,7 +46,7 @@ describe('Replicate client', () => {
4646 } ,
4747 } ) ;
4848 const collection = await client . collections . get ( 'super-resolution' ) ;
49- expect ( client [ 'instance' ] ) . toHaveBeenCalledWith (
49+ expect ( client [ 'instance' ] ) . toHaveBeenCalledWith (
5050 '/collections/super-resolution' ,
5151 {
5252 method : 'GET' ,
@@ -60,7 +60,7 @@ describe('Replicate client', () => {
6060
6161 describe ( 'models.get' , ( ) => {
6262 test ( 'Calls the correct API route' , async ( ) => {
63- client [ 'instance' ] . mockResolvedValueOnce ( {
63+ client [ 'instance' ] . mockResolvedValueOnce ( {
6464 data : {
6565 url : 'https://replicate.com/replicate/hello-world' ,
6666 owner : 'replicate' ,
@@ -77,7 +77,7 @@ describe('Replicate client', () => {
7777 } ,
7878 } ) ;
7979 await client . models . get ( 'replicate' , 'hello-world' ) ;
80- expect ( client [ 'instance' ] ) . toHaveBeenCalledWith (
80+ expect ( client [ 'instance' ] ) . toHaveBeenCalledWith (
8181 '/models/replicate/hello-world' ,
8282 {
8383 method : 'GET' ,
@@ -90,7 +90,7 @@ describe('Replicate client', () => {
9090
9191 describe ( 'predictions.create' , ( ) => {
9292 test ( 'Calls the correct API route with the correct payload' , async ( ) => {
93- client [ 'instance' ] . mockResolvedValueOnce ( {
93+ client [ 'instance' ] . mockResolvedValueOnce ( {
9494 data : {
9595 id : 'ufawqhfynnddngldkgtslldrkq' ,
9696 version :
@@ -121,11 +121,11 @@ describe('Replicate client', () => {
121121 text : 'Alice' ,
122122 } ,
123123 webhook : 'http://test.host/webhook' ,
124- webhook_events_filter : [ 'output' , 'completed' ] ,
124+ webhook_events_filter : [ 'output' , 'completed' ] ,
125125 } ) ;
126126 expect ( prediction . id ) . toBe ( 'ufawqhfynnddngldkgtslldrkq' ) ;
127127
128- expect ( client [ 'instance' ] ) . toHaveBeenCalledWith ( '/predictions' , {
128+ expect ( client [ 'instance' ] ) . toHaveBeenCalledWith ( '/predictions' , {
129129 method : 'POST' ,
130130 data : {
131131 version :
@@ -134,7 +134,7 @@ describe('Replicate client', () => {
134134 text : 'Alice' ,
135135 } ,
136136 webhook : 'http://test.host/webhook' ,
137- webhook_events_filter : [ 'output' , 'completed' ] ,
137+ webhook_events_filter : [ 'output' , 'completed' ] ,
138138 } ,
139139 } ) ;
140140 } ) ;
@@ -144,7 +144,7 @@ describe('Replicate client', () => {
144144
145145 describe ( 'predictions.get' , ( ) => {
146146 test ( 'Calls the correct API route with the correct payload' , async ( ) => {
147- client [ 'instance' ] . mockResolvedValueOnce ( {
147+ client [ 'instance' ] . mockResolvedValueOnce ( {
148148 data : {
149149 id : 'rrr4z55ocneqzikepnug6xezpe' ,
150150 version :
@@ -178,7 +178,7 @@ describe('Replicate client', () => {
178178 ) ;
179179 expect ( prediction . id ) . toBe ( 'rrr4z55ocneqzikepnug6xezpe' ) ;
180180
181- expect ( client [ 'instance' ] ) . toHaveBeenCalledWith (
181+ expect ( client [ 'instance' ] ) . toHaveBeenCalledWith (
182182 '/predictions/rrr4z55ocneqzikepnug6xezpe' ,
183183 {
184184 method : 'GET' ,
@@ -191,7 +191,7 @@ describe('Replicate client', () => {
191191
192192 describe ( 'predictions.list' , ( ) => {
193193 test ( 'Calls the correct API route with the correct payload' , async ( ) => {
194- client [ 'instance' ] . mockResolvedValueOnce ( {
194+ client [ 'instance' ] . mockResolvedValueOnce ( {
195195 data : {
196196 next : 'https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw' ,
197197 previous : null ,
@@ -217,23 +217,23 @@ describe('Replicate client', () => {
217217
218218 const predictions = await client . predictions . list ( ) ;
219219 expect ( predictions . results . length ) . toBe ( 1 ) ;
220- expect ( predictions . results [ 0 ] . id ) . toBe ( 'jpzd7hm5gfcapbfyt4mqytarku' ) ;
220+ expect ( predictions . results [ 0 ] . id ) . toBe ( 'jpzd7hm5gfcapbfyt4mqytarku' ) ;
221221
222- expect ( client [ 'instance' ] ) . toHaveBeenCalledWith ( '/predictions' , {
222+ expect ( client [ 'instance' ] ) . toHaveBeenCalledWith ( '/predictions' , {
223223 method : 'GET' ,
224224 } ) ;
225225 } ) ;
226226
227227 test ( 'Paginates results' , async ( ) => {
228- client [ 'instance' ] . mockResolvedValueOnce ( {
228+ client [ 'instance' ] . mockResolvedValueOnce ( {
229229 data : {
230- results : [ { id : 'ufawqhfynnddngldkgtslldrkq' } ] ,
230+ results : [ { id : 'ufawqhfynnddngldkgtslldrkq' } ] ,
231231 next : 'https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw' ,
232232 } ,
233233 } ) ;
234- client [ 'instance' ] . mockResolvedValueOnce ( {
234+ client [ 'instance' ] . mockResolvedValueOnce ( {
235235 data : {
236- results : [ { id : 'rrr4z55ocneqzikepnug6xezpe' } ] ,
236+ results : [ { id : 'rrr4z55ocneqzikepnug6xezpe' } ] ,
237237 next : null ,
238238 } ,
239239 } ) ;
@@ -248,10 +248,10 @@ describe('Replicate client', () => {
248248 { id : 'rrr4z55ocneqzikepnug6xezpe' } ,
249249 ] ) ;
250250
251- expect ( client [ 'instance' ] ) . toHaveBeenCalledWith ( '/predictions' , {
251+ expect ( client [ 'instance' ] ) . toHaveBeenCalledWith ( '/predictions' , {
252252 method : 'GET' ,
253253 } ) ;
254- expect ( client [ 'instance' ] ) . toHaveBeenCalledWith (
254+ expect ( client [ 'instance' ] ) . toHaveBeenCalledWith (
255255 'https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw' ,
256256 {
257257 method : 'GET' ,
@@ -262,15 +262,129 @@ describe('Replicate client', () => {
262262 // Add more tests for error handling, edge cases, etc.
263263 } ) ;
264264
265+ describe ( 'trainings.create' , ( ) => {
266+ test ( 'Calls the correct API route with the correct payload' , async ( ) => {
267+ client [ 'instance' ] . mockResolvedValueOnce ( {
268+ data : {
269+ "id" : "zz4ibbonubfz7carwiefibzgga" ,
270+ "version" : "{version}" ,
271+ "status" : "starting" ,
272+ "input" : {
273+ "text" : "..."
274+ } ,
275+ "output" : null ,
276+ "error" : null ,
277+ "logs" : null ,
278+ "started_at" : null ,
279+ "created_at" : "2023-03-28T21:47:58.566434Z" ,
280+ "completed_at" : null
281+ }
282+ } ) ;
283+
284+ const training = await client . trainings . create (
285+ 'owner' ,
286+ 'model' ,
287+ '632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532' ,
288+ {
289+ destination : 'new_owner/new_model' ,
290+ input : {
291+ text : '...'
292+ }
293+ }
294+ ) ;
295+ expect ( training . id ) . toBe ( 'zz4ibbonubfz7carwiefibzgga' ) ;
296+
297+ expect ( client [ 'instance' ] ) . toHaveBeenCalledWith ( '/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings' , {
298+ method : 'POST' ,
299+ data : {
300+ destination : 'new_owner/new_model' ,
301+ input : {
302+ text : '...'
303+ } ,
304+ }
305+ } ) ;
306+ } ) ;
307+
308+ // Add more tests for error handling, edge cases, etc.
309+ } ) ;
310+
311+ describe ( 'trainings.get' , ( ) => {
312+ test ( 'Calls the correct API route with the correct payload' , async ( ) => {
313+ client [ 'instance' ] . mockResolvedValueOnce ( {
314+ data : {
315+ "id" : "zz4ibbonubfz7carwiefibzgga" ,
316+ "version" : "{version}" ,
317+ "status" : "succeeded" ,
318+ "input" : {
319+ "data" : "..." ,
320+ "param1" : "..."
321+ } ,
322+ "output" : {
323+ "version" : "..."
324+ } ,
325+ "error" : null ,
326+ "logs" : null ,
327+ "webhook_completed" : null ,
328+ "started_at" : null ,
329+ "created_at" : "2023-03-28T21:47:58.566434Z" ,
330+ "completed_at" : null
331+ }
332+ } ) ;
333+
334+ const training = await client . trainings . get ( 'zz4ibbonubfz7carwiefibzgga' ) ;
335+ expect ( training . status ) . toBe ( 'succeeded' ) ;
336+
337+ expect ( client [ 'instance' ] ) . toHaveBeenCalledWith ( '/trainings/zz4ibbonubfz7carwiefibzgga' , {
338+ method : 'GET' ,
339+ } ) ;
340+ } ) ;
341+
342+ // Add more tests for error handling, edge cases, etc.
343+ } ) ;
344+
345+ describe ( 'trainings.cancel' , ( ) => {
346+ test ( 'Calls the correct API route with the correct payload' , async ( ) => {
347+ client [ 'instance' ] . mockResolvedValueOnce ( {
348+ data : {
349+ "id" : "zz4ibbonubfz7carwiefibzgga" ,
350+ "version" : "{version}" ,
351+ "status" : "canceled" ,
352+ "input" : {
353+ "data" : "..." ,
354+ "param1" : "..."
355+ } ,
356+ "output" : {
357+ "version" : "..."
358+ } ,
359+ "error" : null ,
360+ "logs" : null ,
361+ "webhook_completed" : null ,
362+ "started_at" : null ,
363+ "created_at" : "2023-03-28T21:47:58.566434Z" ,
364+ "completed_at" : null
365+ }
366+ } ) ;
367+
368+ const training = await client . trainings . cancel ( "zz4ibbonubfz7carwiefibzgga" ) ;
369+ expect ( training . status ) . toBe ( 'canceled' ) ;
370+
371+ expect ( client [ 'instance' ] ) . toHaveBeenCalledWith ( '/trainings/zz4ibbonubfz7carwiefibzgga/cancel' , {
372+ method : 'POST' ,
373+ } ) ;
374+ } ) ;
375+
376+ // Add more tests for error handling, edge cases, etc.
377+ } ) ;
378+
265379 describe ( 'run' , ( ) => {
266380 test ( 'Calls the correct API routes' , async ( ) => {
267- client [ 'instance' ] . mockResolvedValueOnce ( {
381+ client [ 'instance' ] . mockResolvedValueOnce ( {
268382 data : {
269383 id : 'ufawqhfynnddngldkgtslldrkq' ,
270384 status : 'processing' ,
271385 } ,
272386 } ) ;
273- client [ 'instance' ] . mockResolvedValueOnce ( {
387+ client [ 'instance' ] . mockResolvedValueOnce ( {
274388 data : {
275389 id : 'ufawqhfynnddngldkgtslldrkq' ,
276390 status : 'succeeded' ,
@@ -283,7 +397,7 @@ describe('Replicate client', () => {
283397 input : { text : 'Hello, world!' } ,
284398 }
285399 ) ;
286- expect ( client [ 'instance' ] ) . toHaveBeenCalledWith ( '/predictions' , {
400+ expect ( client [ 'instance' ] ) . toHaveBeenCalledWith ( '/predictions' , {
287401 method : 'POST' ,
288402 data : {
289403 version :
@@ -293,7 +407,7 @@ describe('Replicate client', () => {
293407 } ,
294408 } ,
295409 } ) ;
296- expect ( client [ 'instance' ] ) . toHaveBeenCalledWith (
410+ expect ( client [ 'instance' ] ) . toHaveBeenCalledWith (
297411 '/predictions/ufawqhfynnddngldkgtslldrkq' ,
298412 {
299413 method : 'GET' ,
0 commit comments