Skip to content

Commit d549814

Browse files
authored
Allow run method to take model argument, when supported (#167)
* Extract model version identifier into separate component * Allow `run` method to take model argument, when supported
1 parent 89d88a0 commit d549814

File tree

4 files changed

+113
-28
lines changed

4 files changed

+113
-28
lines changed

index.d.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ declare module 'replicate' {
8989
fetch: Function;
9090

9191
run(
92-
identifier: `${string}/${string}:${string}`,
92+
identifier: `${string}/${string}` | `${string}/${string}:${string}`,
9393
options: {
9494
input: object;
9595
wait?: { interval?: number };

index.js

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
const ApiError = require('./lib/error');
2+
const ModelVersionIdentifier = require('./lib/identifier');
23
const { withAutomaticRetries } = require('./lib/util');
34

45
const collections = require('./lib/collections');
@@ -91,7 +92,7 @@ class Replicate {
9192
/**
9293
* Run a model and wait for its output.
9394
*
94-
* @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}"
95+
* @param {string} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version"
9596
* @param {object} options
9697
* @param {object} options.input - Required. An object with the model inputs
9798
* @param {object} [options.wait] - Options for waiting for the prediction to finish
@@ -100,37 +101,29 @@ class Replicate {
100101
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
101102
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction
102103
* @param {Function} [progress] - Callback function that receives the prediction object as it's updated. The function is called when the prediction is created, each time its updated while polling for completion, and when it's completed.
104+
* @throws {Error} If the reference is invalid
103105
* @throws {Error} If the prediction failed
104106
* @returns {Promise<object>} - Resolves with the output of running the model
105107
*/
106-
async run(identifier, options, progress) {
108+
async run(ref, options, progress) {
107109
const { wait, ...data } = options;
108110

109-
// Define a pattern for owner and model names that allows
110-
// letters, digits, and certain special characters.
111-
// Example: "user123", "abc__123", "user.name"
112-
const namePattern = /[a-zA-Z0-9]+(?:(?:[._]|__|[-]*)[a-zA-Z0-9]+)*/;
113-
114-
// Define a pattern for "owner/name:version" format with named capturing groups.
115-
// Example: "user123/repo_a:1a2b3c"
116-
const pattern = new RegExp(
117-
`^(?<owner>${namePattern.source})/(?<name>${namePattern.source}):(?<version>[0-9a-fA-F]+)$`
118-
);
119-
120-
const match = identifier.match(pattern);
121-
if (!match || !match.groups) {
122-
throw new Error(
123-
'Invalid version. It must be in the format "owner/name:version"'
111+
const identifier = ModelVersionIdentifier.parse(ref);
112+
113+
let prediction;
114+
if (identifier.version) {
115+
prediction = await this.predictions.create({
116+
...data,
117+
version: identifier.version,
118+
});
119+
} else {
120+
prediction = await this.models.predictions.create(
121+
identifier.owner,
122+
identifier.name,
123+
data
124124
);
125125
}
126126

127-
const { version } = match.groups;
128-
129-
let prediction = await this.predictions.create({
130-
...data,
131-
version,
132-
});
133-
134127
// Call progress callback with the initial prediction object
135128
if (progress) {
136129
progress(prediction);

index.test.ts

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,7 @@ describe('Replicate client', () => {
749749
});
750750

751751
describe('run', () => {
752-
test('Calls the correct API routes', async () => {
752+
test('Calls the correct API routes for a version', async () => {
753753
let firstPollingRequest = true;
754754

755755
nock(BASE_URL)
@@ -808,6 +808,65 @@ describe('Replicate client', () => {
808808
expect(progress).toHaveBeenCalledTimes(4);
809809
});
810810

811+
test('Calls the correct API routes for a model', async () => {
812+
let firstPollingRequest = true;
813+
814+
nock(BASE_URL)
815+
.post('/models/replicate/hello-world/predictions')
816+
.reply(201, {
817+
id: 'ufawqhfynnddngldkgtslldrkq',
818+
status: 'starting',
819+
})
820+
.get('/predictions/ufawqhfynnddngldkgtslldrkq')
821+
.twice()
822+
.reply(200, {
823+
id: 'ufawqhfynnddngldkgtslldrkq',
824+
status: 'processing',
825+
})
826+
.get('/predictions/ufawqhfynnddngldkgtslldrkq')
827+
.reply(200, {
828+
id: 'ufawqhfynnddngldkgtslldrkq',
829+
status: 'succeeded',
830+
output: 'Goodbye!',
831+
});
832+
833+
const progress = jest.fn();
834+
835+
const output = await client.run(
836+
'replicate/hello-world',
837+
{
838+
input: { text: 'Hello, world!' },
839+
wait: { interval: 1 }
840+
},
841+
progress
842+
);
843+
844+
expect(output).toBe('Goodbye!');
845+
846+
expect(progress).toHaveBeenNthCalledWith(1, {
847+
id: 'ufawqhfynnddngldkgtslldrkq',
848+
status: 'starting',
849+
});
850+
851+
expect(progress).toHaveBeenNthCalledWith(2, {
852+
id: 'ufawqhfynnddngldkgtslldrkq',
853+
status: 'processing',
854+
});
855+
856+
expect(progress).toHaveBeenNthCalledWith(3, {
857+
id: 'ufawqhfynnddngldkgtslldrkq',
858+
status: 'processing',
859+
});
860+
861+
expect(progress).toHaveBeenNthCalledWith(4, {
862+
id: 'ufawqhfynnddngldkgtslldrkq',
863+
status: 'succeeded',
864+
output: 'Goodbye!',
865+
});
866+
867+
expect(progress).toHaveBeenCalledTimes(4);
868+
});
869+
811870
test('Does not throw an error for identifier containing hyphen and full stop', async () => {
812871
nock(BASE_URL)
813872
.post('/predictions')
@@ -828,8 +887,6 @@ describe('Replicate client', () => {
828887
test('Throws an error for invalid identifiers', async () => {
829888
const options = { input: { text: 'Hello, world!' } }
830889

831-
await expect(client.run('owner/model:invalid', options)).rejects.toThrow();
832-
833890
// @ts-expect-error
834891
await expect(client.run('owner:abc123', options)).rejects.toThrow();
835892

lib/identifier.js

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* A reference to a model version in the format `owner/name` or `owner/name:version`.
3+
*/
4+
class ModelVersionIdentifier {
5+
/*
6+
* @param {string} Required. The model owner.
7+
* @param {string} Required. The model name.
8+
* @param {string} The model version.
9+
*/
10+
constructor(owner, name, version = null) {
11+
this.owner = owner;
12+
this.name = name;
13+
this.version = version;
14+
}
15+
16+
/*
17+
* Parse a reference to a model version
18+
*
19+
* @param {string}
20+
* @returns {ModelVersionIdentifier}
21+
* @throws {Error} If the reference is invalid.
22+
*/
23+
static parse(ref) {
24+
const match = ref.match(/^(?<owner>[^/]+)\/(?<name>[^/:]+)(:(?<version>.+))?$/);
25+
if (!match) {
26+
throw new Error(`Invalid reference to model version: ${ref}. Expected format: owner/name or owner/name:version`);
27+
}
28+
29+
const { owner, name, version } = match.groups;
30+
31+
return new ModelVersionIdentifier(owner, name, version);
32+
}
33+
}
34+
35+
module.exports = ModelVersionIdentifier;

0 commit comments

Comments
 (0)