Skip to content

Commit 68da42d

Browse files
authored
Add progress callback parameter to run method (#107)
* Fix bug in implementation of wait, where stop isn't called after initial update * Add progress parameter to run method
1 parent fbc3bf2 commit 68da42d

File tree

3 files changed

+68
-14
lines changed

3 files changed

+68
-14
lines changed

index.d.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ declare module 'replicate' {
8989
webhook?: string;
9090
webhook_events_filter?: WebhookEventType[];
9191
signal?: AbortSignal;
92-
}
92+
},
93+
progress?: (Prediction) => void
9394
): Promise<object>;
9495

9596
request(route: string | URL, options: {

index.js

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,11 @@ class Replicate {
8686
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
8787
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
8888
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction
89+
* @param {Function} [progress] - Callback function that receives the prediction object as it's updated. The function is called when the prediction is created, each time its updated while polling for completion, and when it's completed.
8990
* @throws {Error} If the prediction failed
9091
* @returns {Promise<object>} - Resolves with the output of running the model
9192
*/
92-
async run(identifier, options) {
93+
async run(identifier, options, progress) {
9394
const { wait, ...data } = options;
9495

9596
// Define a pattern for owner and model names that allows
@@ -117,17 +118,32 @@ class Replicate {
117118
version,
118119
});
119120

121+
// Call progress callback with the initial prediction object
122+
if (progress) {
123+
progress(prediction);
124+
}
125+
120126
const { signal } = options;
121127

122-
prediction = await this.wait(prediction, wait || {}, async ({ id }) => {
128+
prediction = await this.wait(prediction, wait || {}, async (updatedPrediction) => {
129+
// Call progress callback with the updated prediction object
130+
if (progress) {
131+
progress(updatedPrediction);
132+
}
133+
123134
if (signal && signal.aborted) {
124-
await this.predictions.cancel(id);
135+
await this.predictions.cancel(updatedPrediction.id);
125136
return true; // stop polling
126137
}
127138

128139
return false; // continue polling
129140
});
130141

142+
// Call progress callback with the completed prediction object
143+
if (progress) {
144+
progress(prediction);
145+
}
146+
131147
if (prediction.status === 'failed') {
132148
throw new Error(`Prediction failed: ${prediction.error}`);
133149
}
@@ -252,33 +268,34 @@ class Replicate {
252268
return prediction;
253269
}
254270

255-
let updatedPrediction = await this.predictions.get(id);
256-
257271
// eslint-disable-next-line no-promise-executor-return
258272
const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));
259273

260274
let attempts = 0;
261275
const interval = options.interval || 250;
262276
const max_attempts = options.max_attempts || null;
263277

278+
let updatedPrediction = await this.predictions.get(id);
279+
264280
while (
265281
updatedPrediction.status !== 'succeeded' &&
266282
updatedPrediction.status !== 'failed' &&
267283
updatedPrediction.status !== 'canceled'
268284
) {
285+
/* eslint-disable no-await-in-loop */
286+
if (stop && await stop(updatedPrediction) === true) {
287+
break;
288+
}
289+
269290
attempts += 1;
270291
if (max_attempts && attempts > max_attempts) {
271292
throw new Error(
272293
`Prediction ${id} did not finish after ${max_attempts} attempts`
273294
);
274295
}
275296

276-
/* eslint-disable no-await-in-loop */
277297
await sleep(interval);
278298
updatedPrediction = await this.predictions.get(prediction.id);
279-
if (stop && await stop(updatedPrediction) === true) {
280-
break;
281-
}
282299
/* eslint-enable no-await-in-loop */
283300
}
284301

index.test.ts

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -517,26 +517,62 @@ describe('Replicate client', () => {
517517

518518
describe('run', () => {
519519
test('Calls the correct API routes', async () => {
520+
let firstPollingRequest = true;
521+
520522
nock(BASE_URL)
521523
.post('/predictions')
524+
.reply(201, {
525+
id: 'ufawqhfynnddngldkgtslldrkq',
526+
status: 'starting',
527+
})
528+
.get('/predictions/ufawqhfynnddngldkgtslldrkq')
529+
.twice()
522530
.reply(200, {
523531
id: 'ufawqhfynnddngldkgtslldrkq',
524532
status: 'processing',
525533
})
526534
.get('/predictions/ufawqhfynnddngldkgtslldrkq')
527-
.reply(201, {
535+
.reply(200, {
528536
id: 'ufawqhfynnddngldkgtslldrkq',
529537
status: 'succeeded',
530-
output: 'foobar',
538+
output: 'Goodbye!',
531539
});
532540

541+
const progress = jest.fn();
542+
533543
const output = await client.run(
534544
'owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa',
535545
{
536546
input: { text: 'Hello, world!' },
537-
}
547+
wait: { interval: 1 }
548+
},
549+
progress
538550
);
539-
expect(output).toBe('foobar');
551+
552+
expect(output).toBe('Goodbye!');
553+
554+
expect(progress).toHaveBeenNthCalledWith(1, {
555+
id: 'ufawqhfynnddngldkgtslldrkq',
556+
status: 'starting',
557+
});
558+
559+
expect(progress).toHaveBeenNthCalledWith(2, {
560+
id: 'ufawqhfynnddngldkgtslldrkq',
561+
status: 'processing',
562+
});
563+
564+
expect(progress).toHaveBeenNthCalledWith(3, {
565+
id: 'ufawqhfynnddngldkgtslldrkq',
566+
status: 'processing',
567+
});
568+
569+
expect(progress).toHaveBeenNthCalledWith(4, {
570+
id: 'ufawqhfynnddngldkgtslldrkq',
571+
status: 'succeeded',
572+
output: 'Goodbye!',
573+
});
574+
575+
expect(progress).toHaveBeenCalledTimes(4);
540576
});
541577

542578
test('Does not throw an error for identifier containing hyphen and full stop', async () => {

0 commit comments

Comments
 (0)