Skip to content

Commit bd6c1b7

Browse files
committed
Chain predictions off versions
This connects predictions to versions conceptually, to match their concept elsewhere, and gives us a clear API for differentiating from other kinds of predictions in the future. We create `Version` and `Prediction` classes we don't expect to expose externally. The entrypoint for users continues to be methods called from `ReplicateClient` directly.
1 parent ae83f0e commit bd6c1b7

File tree

8 files changed

+443
-271
lines changed

8 files changed

+443
-271
lines changed

README.md

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,9 @@ To run a prediction and return its output:
1515
```js
1616
import replicate from "replicate";
1717

18-
// Set your model's input parameters here
19-
const input = {
18+
const prediction = await replicate.version("<MODEL VERSION>").predict({
2019
prompt: "painting of a cat by andy warhol",
21-
};
22-
23-
const prediction = await replicate.predict("<MODEL VERSION>", input);
20+
});
2421

2522
console.log(prediction.output);
2623
// "https://replicate.delivery/pbxt/lGWovsQZ7jZuNtPvofMth1rSeCcVn5xes8dWWdWZ64MlTi7gA/out-0.png"
@@ -32,18 +29,34 @@ running, you can pass in an `onUpdate` callback function:
3229
```js
3330
import replicate from "replicate";
3431

35-
// Set your model's input parameters here
36-
const input = {
37-
prompt: "painting of a cat by andy warhol",
38-
};
39-
40-
await replicate.predict("<MODEL VERSION>", input, {
41-
onUpdate: (prediction) => {
42-
console.log(prediction.output);
32+
await replicate.version("<MODEL VERSION>").predict(
33+
{
34+
prompt: "painting of a cat by andy warhol",
4335
},
36+
{
37+
onUpdate: (prediction) => {
38+
console.log(prediction.output);
39+
},
40+
}
41+
);
42+
```
43+
44+
If you'd prefer to control your own polling you can use the low-level
45+
`createPrediction()` method:
46+
47+
```js
48+
import replicate from "replicate";
49+
50+
const prediction = await replicate.version("<MODEL VERSION>").createPrediction({
51+
prompt: "painting of a cat by andy warhol",
4452
});
53+
54+
console.log(prediction.status); // "starting"
4555
```
4656

57+
From there, you can fetch the current status of the prediction using
58+
`await prediction.load()` or `await replicate.prediction(prediction.id).load()`.
59+
4760
## License
4861

4962
[Apache 2.0](LICENSE)

lib/Prediction.js

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import { ReplicateError } from "./errors.js";
2+
import ReplicateObject from "./ReplicateObject.js";
3+
14
export const PredictionStatus = {
25
STARTING: "starting",
36
PROCESSING: "processing",
@@ -6,28 +9,19 @@ export const PredictionStatus = {
69
FAILED: "failed",
710
};
811

9-
export default class Prediction {
10-
static fromJSON(json) {
11-
let props;
12-
try {
13-
props = JSON.parse(json);
14-
} catch {
15-
throw new Error(`Unable to parse JSON: ${json}`);
16-
}
17-
18-
return new this(props);
19-
}
12+
export default class Prediction extends ReplicateObject {
13+
constructor({ id, ...rest }, client) {
14+
super(rest, client);
2015

21-
id;
22-
status;
16+
if (!id) {
17+
throw new ReplicateError("id is required");
18+
}
2319

24-
constructor({ id, status, ...rest }) {
2520
this.id = id;
26-
this.status = status;
21+
}
2722

28-
for (const key in rest) {
29-
this[key] = rest[key];
30-
}
23+
actionForGet() {
24+
return `GET /v1/predictions/${this.id}`;
3125
}
3226

3327
hasTerminalStatus() {

lib/Prediction.test.js

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import { jest } from "@jest/globals";
2+
import { FetchError } from "node-fetch";
3+
import Prediction, { PredictionStatus } from "./Prediction.js";
4+
5+
jest.unstable_mockModule("node-fetch", () => ({
6+
default: jest.fn(),
7+
FetchError,
8+
}));
9+
10+
const { default: ReplicateClient } = await import("./ReplicateClient.js");
11+
12+
let client;
13+
let prediction;
14+
15+
beforeEach(() => {
16+
process.env.REPLICATE_API_TOKEN = "test-token-from-env";
17+
18+
client = new ReplicateClient({});
19+
prediction = client.prediction("test-prediction");
20+
});
21+
22+
describe("load()", () => {
23+
it("makes request to get prediction", async () => {
24+
jest.spyOn(client, "request").mockResolvedValue({
25+
id: "test-prediction",
26+
status: PredictionStatus.SUCCEEDED,
27+
});
28+
29+
await prediction.load();
30+
31+
expect(client.request).toHaveBeenCalledWith(
32+
"GET /v1/predictions/test-prediction"
33+
);
34+
});
35+
36+
it("returns a Prediction", async () => {
37+
jest.spyOn(client, "request").mockResolvedValue({
38+
id: "test-prediction",
39+
status: PredictionStatus.SUCCEEDED,
40+
});
41+
42+
const returnedPrediction = await prediction.load();
43+
44+
expect(returnedPrediction).toBeInstanceOf(Prediction);
45+
});
46+
47+
it("updates the prediction in place", async () => {
48+
jest.spyOn(client, "request").mockResolvedValue({
49+
id: "test-prediction",
50+
status: PredictionStatus.SUCCEEDED,
51+
});
52+
53+
const returnedPrediction = await prediction.load();
54+
55+
expect(returnedPrediction).toBe(prediction);
56+
});
57+
});

lib/ReplicateClient.js

Lines changed: 5 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import {
66
ReplicateResponseError,
77
} from "./errors.js";
88
import Prediction from "./Prediction.js";
9-
import { sleep } from "./utils.js";
9+
import Version from "./Version.js";
1010

1111
export default class ReplicateClient {
1212
baseURL;
@@ -21,79 +21,12 @@ export default class ReplicateClient {
2121
}
2222
}
2323

24-
// TODO: Optionally autocancel prediction on exception.
25-
async predict(
26-
version,
27-
input,
28-
{ onUpdate, onTemporaryError } = {},
29-
{
30-
defaultPollingInterval = 500,
31-
backoffFn = (errorCount) => Math.pow(2, errorCount) * 100,
32-
} = {}
33-
) {
34-
if (!version) {
35-
throw new ReplicateError("version is required");
36-
}
37-
38-
if (!input) {
39-
throw new ReplicateError("input is required");
40-
}
41-
42-
let prediction = await this.createPrediction(version, input);
43-
44-
onUpdate && onUpdate(prediction);
45-
46-
let pollingInterval = defaultPollingInterval;
47-
let errorCount = 0;
48-
49-
while (!prediction.hasTerminalStatus()) {
50-
await sleep(pollingInterval);
51-
pollingInterval = defaultPollingInterval; // Reset to default each time.
52-
53-
try {
54-
prediction = await this.getPrediction(prediction.id);
55-
56-
onUpdate && onUpdate(prediction);
57-
58-
errorCount = 0; // Reset because we've had a non-error response.
59-
} catch (err) {
60-
if (!err instanceof ReplicateResponseError) {
61-
throw err;
62-
}
63-
64-
if (
65-
!err.status ||
66-
(Math.floor(err.status / 100) !== 5 && err.status !== 429)
67-
) {
68-
throw err;
69-
}
70-
71-
errorCount += 1;
72-
73-
onTemporaryError && onTemporaryError(err);
74-
75-
pollingInterval = backoffFn(errorCount);
76-
}
77-
}
78-
79-
return prediction;
80-
}
81-
82-
async createPrediction(version, input) {
83-
const predictionData = await this.request("POST /v1/predictions", {
84-
version,
85-
input,
86-
});
87-
88-
return new Prediction(predictionData);
24+
version(id) {
25+
return new Version({ id }, this);
8926
}
9027

91-
async getPrediction(predictionID) {
92-
const predictionData = await this.request(
93-
`GET /v1/predictions/${predictionID}`
94-
);
95-
96-
return new Prediction(predictionData);
28+
prediction(id) {
29+
return new Prediction({ id }, this);
9730
}
9831

9932
async request(action, body) {

0 commit comments

Comments
 (0)