Skip to content

Commit 7e80fee

Browse files
committed
Support passing AbortSignal into paginate()
1 parent 4a10c04 commit 7e80fee

File tree

3 files changed

+97
-5
lines changed

3 files changed

+97
-5
lines changed

index.d.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,10 @@ declare module "replicate" {
187187
}
188188
): Promise<Response>;
189189

190-
paginate<T>(endpoint: () => Promise<Page<T>>): AsyncGenerator<[T]>;
190+
paginate<T>(
191+
endpoint: () => Promise<Page<T>>,
192+
options?: { signal?: AbortSignal }
193+
): AsyncGenerator<T[]>;
191194

192195
wait(
193196
prediction: Prediction,

index.js

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,15 +356,20 @@ class Replicate {
356356
* console.log(page);
357357
* }
358358
* @param {Function} endpoint - Function that returns a promise for the next page of results
359+
* @param {object} [options]
360+
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the request.
359361
* @yields {object[]} Each page of results
360362
*/
361-
async *paginate(endpoint) {
363+
async *paginate(endpoint, options = {}) {
362364
const response = await endpoint();
363365
yield response.results;
364-
if (response.next) {
366+
if (response.next && !(options.signal && options.signal.aborted)) {
365367
const nextPage = () =>
366-
this.request(response.next, { method: "GET" }).then((r) => r.json());
367-
yield* this.paginate(nextPage);
368+
this.request(response.next, {
369+
method: "GET",
370+
signal: options.signal,
371+
}).then((r) => r.json());
372+
yield* this.paginate(nextPage, options);
368373
}
369374
}
370375

index.test.ts

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,90 @@ describe("Replicate client", () => {
9898
});
9999
});
100100

101+
describe("paginate", () => {
102+
test("pages through results", async () => {
103+
nock(BASE_URL)
104+
.get("/collections")
105+
.reply(200, {
106+
results: [
107+
{
108+
name: "Super resolution",
109+
slug: "super-resolution",
110+
description:
111+
"Upscaling models that create high-quality images from low-quality images.",
112+
},
113+
],
114+
next: `${BASE_URL}/collections?page=2`,
115+
previous: null,
116+
});
117+
nock(BASE_URL)
118+
.get("/collections?page=2")
119+
.reply(200, {
120+
results: [
121+
{
122+
name: "Image classification",
123+
slug: "image-classification",
124+
description: "Models that classify images.",
125+
},
126+
],
127+
next: null,
128+
previous: null,
129+
});
130+
131+
const iterator = client.paginate(client.collections.list);
132+
133+
const firstPage = (await iterator.next()).value;
134+
expect(firstPage.length).toBe(1);
135+
136+
const secondPage = (await iterator.next()).value;
137+
expect(secondPage.length).toBe(1);
138+
});
139+
140+
test("accepts an abort signal", async () => {
141+
nock(BASE_URL)
142+
.get("/collections")
143+
.reply(200, {
144+
results: [
145+
{
146+
name: "Super resolution",
147+
slug: "super-resolution",
148+
description:
149+
"Upscaling models that create high-quality images from low-quality images.",
150+
},
151+
],
152+
next: `${BASE_URL}/collections?page=2`,
153+
previous: null,
154+
});
155+
nock(BASE_URL)
156+
.get("/collections?page=2")
157+
.reply(200, {
158+
results: [
159+
{
160+
name: "Image classification",
161+
slug: "image-classification",
162+
description: "Models that classify images.",
163+
},
164+
],
165+
next: null,
166+
previous: null,
167+
});
168+
169+
const controller = new AbortController();
170+
const iterator = client.paginate(client.collections.list, {
171+
signal: controller.signal,
172+
});
173+
174+
const firstIteration = await iterator.next();
175+
expect(firstIteration.value.length).toBe(1);
176+
177+
controller.abort();
178+
179+
const secondIteration = await iterator.next();
180+
expect(secondIteration.value).toBeUndefined();
181+
expect(secondIteration.done).toBe(true);
182+
});
183+
});
184+
101185
describe("account.get", () => {
102186
test("Calls the correct API route", async () => {
103187
nock(BASE_URL).get("/account").reply(200, {

0 commit comments

Comments
 (0)