Commit 0a7ec78

Merge branch 'main' into mattt/deprecate-stream-parameter

2 parents: 896c13a + eb9adfd

File tree

10 files changed: +404 -24 lines changed


.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@ jobs:
       matrix:
         suite: [node]
         # See supported Node.js release schedule at https://nodejs.org/en/about/previous-releases
-        node-version: [18.x, 20.x, 22.4] # TODO: unpin to 22.x once https://github.com/nodejs/node/issues/53902 is resolved
+        node-version: [18.x, 20.x, 22.x]
 
     steps:
       - uses: actions/checkout@v4

README.md

Lines changed: 12 additions & 0 deletions

@@ -455,6 +455,18 @@ const response = await replicate.models.list();
 }
 ```
 
+### `replicate.models.search`
+
+Search for public models on Replicate.
+
+```js
+const response = await replicate.models.search(query);
+```
+
+| name    | type   | description                            |
+| ------- | ------ | -------------------------------------- |
+| `query` | string | **Required**. The search query string. |
+
 ### `replicate.models.create`
 
 Create a new public or private model.
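
As a quick illustration of the documented method, here is a minimal, hypothetical usage sketch; it assumes the page shape (`results`, `next`, `previous`) returned in the tests further down, and the logged fields are only examples:

```js
import Replicate from "replicate";

const replicate = new Replicate({ auth: process.env.REPLICATE_API_TOKEN });

// Search for public models matching a free-text query.
const page = await replicate.models.search("llama");

// `page.results` is an array of Model objects; `page.next` and `page.previous`
// hold pagination links (or null), as with the other list endpoints.
for (const model of page.results) {
  console.log(`${model.owner}/${model.name}: ${model.run_count} runs`);
}
```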

index.d.ts

Lines changed: 8 additions & 0 deletions

@@ -8,6 +8,12 @@ declare module "replicate" {
     response: Response;
   }
 
+  export interface FileOutput extends ReadableStream {
+    blob(): Promise<Blob>;
+    url(): URL;
+    toString(): string;
+  }
+
   export interface Account {
     type: "user" | "organization";
     username: string;
@@ -137,6 +143,7 @@ declare module "replicate" {
       init?: RequestInit
     ) => Promise<Response>;
     fileEncodingStrategy?: FileEncodingStrategy;
+    useFileOutput?: boolean;
   });
 
   auth: string;
@@ -281,6 +288,7 @@ declare module "replicate" {
        version_id: string
      ): Promise<ModelVersion>;
    };
+   search(query: string): Promise<Page<Model>>;
  };
 
  predictions: {
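
To make the new `FileOutput` shape concrete, here is a small sketch of how a caller might consume it, assuming a client constructed with `useFileOutput: true` (the constructor option added above) and a placeholder model identifier:

```js
// Sketch only: "owner/model" is a placeholder, and the client is assumed to
// have been created with `useFileOutput: true`.
const output = await replicate.run("owner/model", {
  input: { prompt: "Hello, world!" },
});

console.log(output.url());  // URL object (https: or data:) for the file
console.log(`${output}`);   // toString() keeps string-style interpolation working

// Buffer the whole file; alternatively, read `output` directly as a
// ReadableStream, but a stream can only be consumed once, so pick one path.
const blob = await output.blob();
```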

index.js

Lines changed: 17 additions & 3 deletions

@@ -1,7 +1,8 @@
 const ApiError = require("./lib/error");
 const ModelVersionIdentifier = require("./lib/identifier");
-const { createReadableStream } = require("./lib/stream");
+const { createReadableStream, createFileOutput } = require("./lib/stream");
 const {
+  transform,
   withAutomaticRetries,
   validateWebhook,
   parseProgressFromLogs,
@@ -47,6 +48,7 @@ class Replicate {
    * @param {string} options.userAgent - Identifier of your app
    * @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
    * @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
+   * @param {boolean} [options.useFileOutput] - Set to `true` to return `FileOutput` objects from `run` instead of URLs, defaults to false.
    * @param {"default" | "upload" | "data-uri"} [options.fileEncodingStrategy] - Determines the file encoding strategy to use
    */
   constructor(options = {}) {
@@ -57,7 +59,8 @@ class Replicate {
       options.userAgent || `replicate-javascript/${packageJSON.version}`;
     this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";
     this.fetch = options.fetch || globalThis.fetch;
-    this.fileEncodingStrategy = options.fileEncodingStrategy ?? "default";
+    this.fileEncodingStrategy = options.fileEncodingStrategy || "default";
+    this.useFileOutput = options.useFileOutput || false;
 
     this.accounts = {
       current: accounts.current.bind(this),
@@ -98,6 +101,7 @@ class Replicate {
         list: models.versions.list.bind(this),
         get: models.versions.get.bind(this),
       },
+      search: models.search.bind(this),
     };
 
     this.predictions = {
@@ -195,7 +199,17 @@ class Replicate {
       throw new Error(`Prediction failed: ${prediction.error}`);
     }
 
-    return prediction.output;
+    return transform(prediction.output, (value) => {
+      if (
+        typeof value === "string" &&
+        (value.startsWith("https:") || value.startsWith("data:"))
+      ) {
+        return this.useFileOutput
+          ? createFileOutput({ url: value, fetch: this.fetch })
+          : value;
+      }
+      return value;
+    });
   }
 
   /**
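
The `transform` helper added to the imports is not defined in this diff; from its use in `run`, it appears to map a function over each value in the prediction output so that nested URLs get wrapped as well. A rough, assumed sketch of that idea (not the library's actual implementation):

```js
// Assumed behavior only: recursively apply `mapper` to every leaf value of the
// output, preserving arrays and plain objects. The real helper may differ.
function transform(value, mapper) {
  if (Array.isArray(value)) {
    return value.map((item) => transform(item, mapper));
  }
  if (value !== null && typeof value === "object") {
    return Object.fromEntries(
      Object.entries(value).map(([key, item]) => [key, transform(item, mapper)])
    );
  }
  return mapper(value);
}
```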

index.test.ts

Lines changed: 238 additions & 3 deletions

@@ -1,13 +1,13 @@
 import { expect, jest, test } from "@jest/globals";
 import Replicate, {
   ApiError,
+  FileOutput,
   Model,
   Prediction,
   validateWebhook,
   parseProgressFromLogs,
 } from "replicate";
 import nock from "nock";
-import { Readable } from "node:stream";
 import { createReadableStream } from "./lib/stream";
 
 let client: Replicate;
@@ -1053,7 +1053,7 @@ describe("Replicate client", () => {
   describe("predictions.create with model", () => {
     test("Calls the correct API route with the correct payload", async () => {
       nock(BASE_URL)
-        .post("/models/meta/llama-2-70b-chat/predictions")
+        .post("/models/meta/meta-llama-3-70b-instruct/predictions")
         .reply(200, {
           id: "heat2o3bzn3ahtr6bjfftvbaci",
           model: "replicate/lifeboat-70b",
@@ -1072,7 +1072,7 @@ describe("Replicate client", () => {
         },
       });
       const prediction = await client.predictions.create({
-        model: "meta/llama-2-70b-chat",
+        model: "meta/meta-llama-3-70b-instruct",
         input: {
           prompt: "Please write a haiku about llamas.",
         },
@@ -1217,6 +1217,44 @@ describe("Replicate client", () => {
     });
   });
 
+  describe("models.search", () => {
+    test("Calls the correct API route with the correct payload", async () => {
+      nock(BASE_URL)
+        .intercept("/models", "QUERY")
+        .reply(200, {
+          results: [
+            {
+              url: "https://replicate.com/meta/meta-llama-3-70b-instruct",
+              owner: "meta",
+              name: "meta-llama-3-70b-instruct",
+              description:
+                "Llama 2 is a collection of pretrained and fine-tuned generative text models ranging in scale from 7 billion to 70 billion parameters.",
+              visibility: "public",
+              github_url: null,
+              paper_url:
+                "https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/",
+              license_url: "https://ai.meta.com/llama/license/",
+              run_count: 1000000,
+              cover_image_url:
+                "https://replicate.delivery/pbxt/IJqFrnAKEDiCBnlXyndzVVxkZvfQ7kLjGVEZZPXTRXxOOPkQA/llama2.png",
+              default_example: null,
+              latest_version: null,
+            },
+            // ... more results ...
+          ],
+          next: null,
+          previous: null,
+        });
+
+      const searchResults = await client.models.search("llama");
+      expect(searchResults.results.length).toBeGreaterThan(0);
+      expect(searchResults.results[0].owner).toBe("meta");
+      expect(searchResults.results[0].name).toBe("meta-llama-3-70b-instruct");
+    });
+
+    // Add more tests for error handling, edge cases, etc.
+  });
+
   describe("run", () => {
     test("Calls the correct API routes", async () => {
       nock(BASE_URL)
@@ -1504,6 +1542,203 @@ describe("Replicate client", () => {
 
       scope.done();
     });
+
+    test("returns FileOutput for URLs when useFileOutput is true", async () => {
+      client = new Replicate({ auth: "foo", useFileOutput: true });
+
+      nock(BASE_URL)
+        .post("/predictions")
+        .reply(201, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "starting",
+          logs: null,
+        })
+        .get("/predictions/ufawqhfynnddngldkgtslldrkq")
+        .reply(200, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "processing",
+          logs: [].join("\n"),
+        })
+        .get("/predictions/ufawqhfynnddngldkgtslldrkq")
+        .reply(200, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "processing",
+          logs: [].join("\n"),
+        })
+        .get("/predictions/ufawqhfynnddngldkgtslldrkq")
+        .reply(200, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "succeeded",
+          output: "https://example.com",
+          logs: [].join("\n"),
+        });
+
+      nock("https://example.com")
+        .get("/")
+        .reply(200, "hello world", { "Content-Type": "text/plain" });
+
+      const output = (await client.run(
+        "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
+        {
+          input: { text: "Hello, world!" },
+        }
+      )) as FileOutput;
+
+      expect(output).toBeInstanceOf(ReadableStream);
+      expect(output.url()).toEqual(new URL("https://example.com"));
+
+      const blob = await output.blob();
+      expect(blob.type).toEqual("text/plain");
+      expect(blob.arrayBuffer()).toEqual(
+        new Blob(["Hello, world!"]).arrayBuffer()
+      );
+    });
+
+    test("returns FileOutput for URLs when useFileOutput is true - acts like string", async () => {
+      client = new Replicate({ auth: "foo", useFileOutput: true });
+
+      nock(BASE_URL)
+        .post("/predictions")
+        .reply(201, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "starting",
+          logs: null,
+        })
+        .get("/predictions/ufawqhfynnddngldkgtslldrkq")
+        .reply(200, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "processing",
+          logs: [].join("\n"),
+        })
+        .get("/predictions/ufawqhfynnddngldkgtslldrkq")
+        .reply(200, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "processing",
+          logs: [].join("\n"),
+        })
+        .get("/predictions/ufawqhfynnddngldkgtslldrkq")
+        .reply(200, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "succeeded",
+          output: "https://example.com",
+          logs: [].join("\n"),
+        });
+
+      nock("https://example.com")
+        .get("/")
+        .reply(200, "hello world", { "Content-Type": "text/plain" });
+
+      const output = (await client.run(
+        "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
+        {
+          input: { text: "Hello, world!" },
+        }
+      )) as unknown as string;
+
+      expect(fetch(output).then((r) => r.text())).resolves.toEqual(
+        "hello world"
+      );
+    });
+
+    test("returns FileOutput for URLs when useFileOutput is true - array output", async () => {
+      client = new Replicate({ auth: "foo", useFileOutput: true });
+
+      nock(BASE_URL)
+        .post("/predictions")
+        .reply(201, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "starting",
+          logs: null,
+        })
+        .get("/predictions/ufawqhfynnddngldkgtslldrkq")
+        .reply(200, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "processing",
+          logs: [].join("\n"),
+        })
+        .get("/predictions/ufawqhfynnddngldkgtslldrkq")
+        .reply(200, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "processing",
+          logs: [].join("\n"),
+        })
+        .get("/predictions/ufawqhfynnddngldkgtslldrkq")
+        .reply(200, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "succeeded",
+          output: ["https://example.com"],
+          logs: [].join("\n"),
+        });
+
+      nock("https://example.com")
+        .get("/")
+        .reply(200, "hello world", { "Content-Type": "text/plain" });
+
+      const [output] = (await client.run(
+        "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
+        {
+          input: { text: "Hello, world!" },
+        }
+      )) as FileOutput[];
+
+      expect(output).toBeInstanceOf(ReadableStream);
+      expect(output.url()).toEqual(new URL("https://example.com"));
+
+      const blob = await output.blob();
+      expect(blob.type).toEqual("text/plain");
+      expect(blob.arrayBuffer()).toEqual(
+        new Blob(["Hello, world!"]).arrayBuffer()
+      );
+    });
+
+    test("returns FileOutput for URLs when useFileOutput is true - data uri", async () => {
+      client = new Replicate({ auth: "foo", useFileOutput: true });
+
+      nock(BASE_URL)
+        .post("/predictions")
+        .reply(201, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "starting",
+          logs: null,
+        })
+        .get("/predictions/ufawqhfynnddngldkgtslldrkq")
+        .reply(200, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "processing",
+          logs: [].join("\n"),
+        })
+        .get("/predictions/ufawqhfynnddngldkgtslldrkq")
+        .reply(200, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "processing",
+          logs: [].join("\n"),
+        })
+        .get("/predictions/ufawqhfynnddngldkgtslldrkq")
+        .reply(200, {
+          id: "ufawqhfynnddngldkgtslldrkq",
+          status: "succeeded",
+          output: "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==",
+          logs: [].join("\n"),
+        });
+
+      const output = (await client.run(
+        "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
+        {
+          input: { text: "Hello, world!" },
+        }
+      )) as FileOutput;
+
+      expect(output).toBeInstanceOf(ReadableStream);
+      expect(output.url()).toEqual(
+        new URL("data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==")
+      );
+
+      const blob = await output.blob();
+      expect(blob.type).toEqual("text/plain");
+      expect(blob.arrayBuffer()).toEqual(
+        new Blob(["Hello, world!"]).arrayBuffer()
+      );
+    });
   });
 
   describe("webhooks.default.secret.get", () => {
