Skip to content

Commit f2f0420

Browse files
authored
Add support for training endpoints (#35)
* Add support for training endpoints * Specify string format for destination parameter * Update type definition
1 parent 4da0d9e commit f2f0420

File tree

4 files changed

+214
-24
lines changed

4 files changed

+214
-24
lines changed

index.d.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,5 +113,21 @@ declare module 'replicate' {
113113
get(prediction_id: string): Promise<Prediction>;
114114
list(): Promise<Page<Prediction>>;
115115
};
116+
117+
trainings: {
118+
create(
119+
model_owner: string,
120+
model_name: string,
121+
version_id: string,
122+
options: {
123+
destination: `${string}/${string}`;
124+
input: object;
125+
webhook?: string;
126+
webhook_events_filter?: WebhookEventType[];
127+
}
128+
): Promise<Training>;
129+
get(options: TrainingsGetOptions): Promise<Training>;
130+
cancel(options: TrainingsGetOptions): Promise<Training>;
131+
};
116132
}
117133
}

index.js

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ const axios = require('axios');
33
const collections = require('./lib/collections');
44
const models = require('./lib/models');
55
const predictions = require('./lib/predictions');
6+
const trainings = require('./lib/trainings');
67
const packageJSON = require('./package.json');
78

89
/**
@@ -63,6 +64,12 @@ class Replicate {
6364
get: predictions.get.bind(this),
6465
list: predictions.list.bind(this),
6566
};
67+
68+
this.trainings = {
69+
create: trainings.create.bind(this),
70+
get: trainings.get.bind(this),
71+
cancel: trainings.cancel.bind(this),
72+
};
6673
}
6774

6875
/**

index.test.ts

Lines changed: 138 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ describe('Replicate client', () => {
99

1010
beforeEach(() => {
1111
client = new Replicate({ auth: 'test-token' });
12-
client['instance'] = jest.fn<typeof axios>();
12+
client[ 'instance' ] = jest.fn<typeof axios>();
1313
});
1414

1515
describe('constructor', () => {
@@ -36,7 +36,7 @@ describe('Replicate client', () => {
3636

3737
describe('collections.get', () => {
3838
test('Calls the correct API route', async () => {
39-
client['instance'].mockResolvedValueOnce({
39+
client[ 'instance' ].mockResolvedValueOnce({
4040
data: {
4141
name: 'Super resolution',
4242
slug: 'super-resolution',
@@ -46,7 +46,7 @@ describe('Replicate client', () => {
4646
},
4747
});
4848
const collection = await client.collections.get('super-resolution');
49-
expect(client['instance']).toHaveBeenCalledWith(
49+
expect(client[ 'instance' ]).toHaveBeenCalledWith(
5050
'/collections/super-resolution',
5151
{
5252
method: 'GET',
@@ -60,7 +60,7 @@ describe('Replicate client', () => {
6060

6161
describe('models.get', () => {
6262
test('Calls the correct API route', async () => {
63-
client['instance'].mockResolvedValueOnce({
63+
client[ 'instance' ].mockResolvedValueOnce({
6464
data: {
6565
url: 'https://replicate.com/replicate/hello-world',
6666
owner: 'replicate',
@@ -77,7 +77,7 @@ describe('Replicate client', () => {
7777
},
7878
});
7979
await client.models.get('replicate', 'hello-world');
80-
expect(client['instance']).toHaveBeenCalledWith(
80+
expect(client[ 'instance' ]).toHaveBeenCalledWith(
8181
'/models/replicate/hello-world',
8282
{
8383
method: 'GET',
@@ -90,7 +90,7 @@ describe('Replicate client', () => {
9090

9191
describe('predictions.create', () => {
9292
test('Calls the correct API route with the correct payload', async () => {
93-
client['instance'].mockResolvedValueOnce({
93+
client[ 'instance' ].mockResolvedValueOnce({
9494
data: {
9595
id: 'ufawqhfynnddngldkgtslldrkq',
9696
version:
@@ -121,11 +121,11 @@ describe('Replicate client', () => {
121121
text: 'Alice',
122122
},
123123
webhook: 'http://test.host/webhook',
124-
webhook_events_filter: ['output', 'completed'],
124+
webhook_events_filter: [ 'output', 'completed' ],
125125
});
126126
expect(prediction.id).toBe('ufawqhfynnddngldkgtslldrkq');
127127

128-
expect(client['instance']).toHaveBeenCalledWith('/predictions', {
128+
expect(client[ 'instance' ]).toHaveBeenCalledWith('/predictions', {
129129
method: 'POST',
130130
data: {
131131
version:
@@ -134,7 +134,7 @@ describe('Replicate client', () => {
134134
text: 'Alice',
135135
},
136136
webhook: 'http://test.host/webhook',
137-
webhook_events_filter: ['output', 'completed'],
137+
webhook_events_filter: [ 'output', 'completed' ],
138138
},
139139
});
140140
});
@@ -144,7 +144,7 @@ describe('Replicate client', () => {
144144

145145
describe('predictions.get', () => {
146146
test('Calls the correct API route with the correct payload', async () => {
147-
client['instance'].mockResolvedValueOnce({
147+
client[ 'instance' ].mockResolvedValueOnce({
148148
data: {
149149
id: 'rrr4z55ocneqzikepnug6xezpe',
150150
version:
@@ -178,7 +178,7 @@ describe('Replicate client', () => {
178178
);
179179
expect(prediction.id).toBe('rrr4z55ocneqzikepnug6xezpe');
180180

181-
expect(client['instance']).toHaveBeenCalledWith(
181+
expect(client[ 'instance' ]).toHaveBeenCalledWith(
182182
'/predictions/rrr4z55ocneqzikepnug6xezpe',
183183
{
184184
method: 'GET',
@@ -191,7 +191,7 @@ describe('Replicate client', () => {
191191

192192
describe('predictions.list', () => {
193193
test('Calls the correct API route with the correct payload', async () => {
194-
client['instance'].mockResolvedValueOnce({
194+
client[ 'instance' ].mockResolvedValueOnce({
195195
data: {
196196
next: 'https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw',
197197
previous: null,
@@ -217,23 +217,23 @@ describe('Replicate client', () => {
217217

218218
const predictions = await client.predictions.list();
219219
expect(predictions.results.length).toBe(1);
220-
expect(predictions.results[0].id).toBe('jpzd7hm5gfcapbfyt4mqytarku');
220+
expect(predictions.results[ 0 ].id).toBe('jpzd7hm5gfcapbfyt4mqytarku');
221221

222-
expect(client['instance']).toHaveBeenCalledWith('/predictions', {
222+
expect(client[ 'instance' ]).toHaveBeenCalledWith('/predictions', {
223223
method: 'GET',
224224
});
225225
});
226226

227227
test('Paginates results', async () => {
228-
client['instance'].mockResolvedValueOnce({
228+
client[ 'instance' ].mockResolvedValueOnce({
229229
data: {
230-
results: [{ id: 'ufawqhfynnddngldkgtslldrkq' }],
230+
results: [ { id: 'ufawqhfynnddngldkgtslldrkq' } ],
231231
next: 'https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw',
232232
},
233233
});
234-
client['instance'].mockResolvedValueOnce({
234+
client[ 'instance' ].mockResolvedValueOnce({
235235
data: {
236-
results: [{ id: 'rrr4z55ocneqzikepnug6xezpe' }],
236+
results: [ { id: 'rrr4z55ocneqzikepnug6xezpe' } ],
237237
next: null,
238238
},
239239
});
@@ -248,10 +248,10 @@ describe('Replicate client', () => {
248248
{ id: 'rrr4z55ocneqzikepnug6xezpe' },
249249
]);
250250

251-
expect(client['instance']).toHaveBeenCalledWith('/predictions', {
251+
expect(client[ 'instance' ]).toHaveBeenCalledWith('/predictions', {
252252
method: 'GET',
253253
});
254-
expect(client['instance']).toHaveBeenCalledWith(
254+
expect(client[ 'instance' ]).toHaveBeenCalledWith(
255255
'https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw',
256256
{
257257
method: 'GET',
@@ -262,15 +262,129 @@ describe('Replicate client', () => {
262262
// Add more tests for error handling, edge cases, etc.
263263
});
264264

265+
describe('trainings.create', () => {
266+
test('Calls the correct API route with the correct payload', async () => {
267+
client[ 'instance' ].mockResolvedValueOnce({
268+
data: {
269+
"id": "zz4ibbonubfz7carwiefibzgga",
270+
"version": "{version}",
271+
"status": "starting",
272+
"input": {
273+
"text": "..."
274+
},
275+
"output": null,
276+
"error": null,
277+
"logs": null,
278+
"started_at": null,
279+
"created_at": "2023-03-28T21:47:58.566434Z",
280+
"completed_at": null
281+
}
282+
});
283+
284+
const training = await client.trainings.create(
285+
'owner',
286+
'model',
287+
'632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532',
288+
{
289+
destination: 'new_owner/new_model',
290+
input: {
291+
text: '...'
292+
}
293+
}
294+
);
295+
expect(training.id).toBe('zz4ibbonubfz7carwiefibzgga');
296+
297+
expect(client[ 'instance' ]).toHaveBeenCalledWith('/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings', {
298+
method: 'POST',
299+
data: {
300+
destination: 'new_owner/new_model',
301+
input: {
302+
text: '...'
303+
},
304+
}
305+
});
306+
});
307+
308+
// Add more tests for error handling, edge cases, etc.
309+
});
310+
311+
describe('trainings.get', () => {
312+
test('Calls the correct API route with the correct payload', async () => {
313+
client[ 'instance' ].mockResolvedValueOnce({
314+
data: {
315+
"id": "zz4ibbonubfz7carwiefibzgga",
316+
"version": "{version}",
317+
"status": "succeeded",
318+
"input": {
319+
"data": "...",
320+
"param1": "..."
321+
},
322+
"output": {
323+
"version": "..."
324+
},
325+
"error": null,
326+
"logs": null,
327+
"webhook_completed": null,
328+
"started_at": null,
329+
"created_at": "2023-03-28T21:47:58.566434Z",
330+
"completed_at": null
331+
}
332+
});
333+
334+
const training = await client.trainings.get('zz4ibbonubfz7carwiefibzgga');
335+
expect(training.status).toBe('succeeded');
336+
337+
expect(client[ 'instance' ]).toHaveBeenCalledWith('/trainings/zz4ibbonubfz7carwiefibzgga', {
338+
method: 'GET',
339+
});
340+
});
341+
342+
// Add more tests for error handling, edge cases, etc.
343+
});
344+
345+
describe('trainings.cancel', () => {
346+
test('Calls the correct API route with the correct payload', async () => {
347+
client[ 'instance' ].mockResolvedValueOnce({
348+
data: {
349+
"id": "zz4ibbonubfz7carwiefibzgga",
350+
"version": "{version}",
351+
"status": "canceled",
352+
"input": {
353+
"data": "...",
354+
"param1": "..."
355+
},
356+
"output": {
357+
"version": "..."
358+
},
359+
"error": null,
360+
"logs": null,
361+
"webhook_completed": null,
362+
"started_at": null,
363+
"created_at": "2023-03-28T21:47:58.566434Z",
364+
"completed_at": null
365+
}
366+
});
367+
368+
const training = await client.trainings.cancel("zz4ibbonubfz7carwiefibzgga");
369+
expect(training.status).toBe('canceled');
370+
371+
expect(client[ 'instance' ]).toHaveBeenCalledWith('/trainings/zz4ibbonubfz7carwiefibzgga/cancel', {
372+
method: 'POST',
373+
});
374+
});
375+
376+
// Add more tests for error handling, edge cases, etc.
377+
});
378+
265379
describe('run', () => {
266380
test('Calls the correct API routes', async () => {
267-
client['instance'].mockResolvedValueOnce({
381+
client[ 'instance' ].mockResolvedValueOnce({
268382
data: {
269383
id: 'ufawqhfynnddngldkgtslldrkq',
270384
status: 'processing',
271385
},
272386
});
273-
client['instance'].mockResolvedValueOnce({
387+
client[ 'instance' ].mockResolvedValueOnce({
274388
data: {
275389
id: 'ufawqhfynnddngldkgtslldrkq',
276390
status: 'succeeded',
@@ -283,7 +397,7 @@ describe('Replicate client', () => {
283397
input: { text: 'Hello, world!' },
284398
}
285399
);
286-
expect(client['instance']).toHaveBeenCalledWith('/predictions', {
400+
expect(client[ 'instance' ]).toHaveBeenCalledWith('/predictions', {
287401
method: 'POST',
288402
data: {
289403
version:
@@ -293,7 +407,7 @@ describe('Replicate client', () => {
293407
},
294408
},
295409
});
296-
expect(client['instance']).toHaveBeenCalledWith(
410+
expect(client[ 'instance' ]).toHaveBeenCalledWith(
297411
'/predictions/ufawqhfynnddngldkgtslldrkq',
298412
{
299413
method: 'GET',

lib/trainings.js

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/**
2+
* Create a new training
3+
*
4+
* @param {string} model_owner - Required. The username of the user or organization who owns the model
5+
* @param {string} model_name - Required. The name of the model
6+
* @param {string} version_id - Required. The version ID
7+
* @param {object} options
8+
* @param {string} options.destination - Required. The destination for the trained version in the form "{username}/{model_name}"
9+
* @param {object} options.input - Required. An object with the model inputs
10+
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the training updates
11+
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
12+
* @returns {Promise<object>} Resolves with the data for the created training
13+
*/
14+
async function createTraining(model_owner, model_name, version_id, options) {
15+
const { ...data } = options;
16+
17+
const training = this.request(`/models/${model_owner}/${model_name}/versions/${version_id}/trainings`, {
18+
method: 'POST',
19+
data,
20+
});
21+
22+
return training;
23+
}
24+
25+
/**
26+
* Fetch a training by ID
27+
*
28+
* @param {string} training_id - Required. The training ID
29+
* @returns {Promise<object>} Resolves with the data for the training
30+
*/
31+
async function getTraining(training_id) {
32+
return this.request(`/trainings/${training_id}`, {
33+
method: 'GET',
34+
});
35+
}
36+
37+
/**
38+
* Cancel a training by ID
39+
*
40+
* @param {string} training_id - Required. The training ID
41+
* @returns {Promise<object>} Resolves with the data for the training
42+
*/
43+
async function cancelTraining(training_id) {
44+
return this.request(`/trainings/${training_id}/cancel`, {
45+
method: 'POST',
46+
});
47+
}
48+
49+
module.exports = {
50+
create: createTraining,
51+
get: getTraining,
52+
cancel: cancelTraining,
53+
};

0 commit comments

Comments
 (0)