Skip to content

Commit 70c54fe

Browse files
authored
Add replicate.models.list method (#142)
1 parent f828839 commit 70c54fe

File tree

5 files changed

+79
-7
lines changed

5 files changed

+79
-7
lines changed

README.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,41 @@ const response = await replicate.models.get(model_owner, model_name);
201201
}
202202
```
203203

204+
### `replicate.models.list`
205+
206+
Get a paginated list of all public models.
207+
208+
```js
209+
const response = await replicate.models.list();
210+
```
211+
212+
```jsonc
213+
{
214+
"next": null,
215+
"previous": null,
216+
"results": [
217+
{
218+
"url": "https://replicate.com/replicate/hello-world",
219+
"owner": "replicate",
220+
"name": "hello-world",
221+
"description": "A tiny model that says hello",
222+
"visibility": "public",
223+
"github_url": "https://github.com/replicate/cog-examples",
224+
"paper_url": null,
225+
"license_url": null,
226+
"run_count": 5681081,
227+
"cover_image_url": "...",
228+
"default_example": {
229+
/* ... */
230+
},
231+
"latest_version": {
232+
/* ... */
233+
}
234+
}
235+
]
236+
}
237+
```
238+
204239
### `replicate.models.versions.list`
205240

206241
Get a list of all published versions of a model, including input and output schemas for each version.

index.d.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ declare module 'replicate' {
117117

118118
models: {
119119
get(model_owner: string, model_name: string): Promise<Model>;
120+
list(): Promise<Page<Model>>;
120121
versions: {
121122
list(model_owner: string, model_name: string): Promise<ModelVersion[]>;
122123
get(

index.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class Replicate {
5151

5252
this.models = {
5353
get: models.get.bind(this),
54+
list: models.list.bind(this),
5455
versions: {
5556
list: models.versions.list.bind(this),
5657
get: models.versions.get.bind(this),

index.test.ts

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { expect, jest, test } from '@jest/globals';
2-
import Replicate, { ApiError, Prediction } from 'replicate';
2+
import Replicate, { ApiError, Model, Prediction } from 'replicate';
33
import nock from 'nock';
44
import fetch from 'cross-fetch';
55

@@ -131,6 +131,30 @@ describe('Replicate client', () => {
131131
// Add more tests for error handling, edge cases, etc.
132132
});
133133

134+
describe('models.list', () => {
135+
test('Paginates results', async () => {
136+
nock(BASE_URL)
137+
.get('/models')
138+
.reply(200, {
139+
results: [{ url: 'https://replicate.com/some-user/model-1' }],
140+
next: 'https://api.replicate.com/v1/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw',
141+
})
142+
.get('/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw')
143+
.reply(200, {
144+
results: [{ url: 'https://replicate.com/some-user/model-2' }],
145+
next: null,
146+
});
147+
148+
const results: Model[] = [];
149+
for await (const batch of client.paginate(client.models.list)) {
150+
results.push(...batch);
151+
}
152+
expect(results).toEqual([{ url: 'https://replicate.com/some-user/model-1' }, { url: 'https://replicate.com/some-user/model-2' }]);
153+
154+
// Add more tests for error handling, edge cases, etc.
155+
});
156+
});
157+
134158
describe('predictions.create', () => {
135159
test('Calls the correct API route with the correct payload', async () => {
136160
nock(BASE_URL)

lib/models.js

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,28 @@ async function listModelVersions(model_owner, model_name) {
3737
* @returns {Promise<object>} Resolves with the model version data
3838
*/
3939
async function getModelVersion(model_owner, model_name, version_id) {
40-
const response = await this.request(
41-
`/models/${model_owner}/${model_name}/versions/${version_id}`,
42-
{
43-
method: 'GET',
44-
}
45-
);
40+
const response = await this.request(`/models/${model_owner}/${model_name}/versions/${version_id}`, {
41+
method: 'GET',
42+
});
43+
44+
return response.json();
45+
}
46+
47+
/**
48+
* List all public models
49+
*
50+
* @returns {Promise<object>} Resolves with the model version data
51+
*/
52+
async function listModels() {
53+
const response = await this.request('/models', {
54+
method: 'GET',
55+
});
4656

4757
return response.json();
4858
}
4959

5060
module.exports = {
5161
get: getModel,
62+
list: listModels,
5263
versions: { list: listModelVersions, get: getModelVersion },
5364
};

0 commit comments

Comments
 (0)