diff --git a/index.d.ts b/index.d.ts index a0cba1d..fee693d 100644 --- a/index.d.ts +++ b/index.d.ts @@ -19,6 +19,7 @@ declare module "replicate" { username: string; name: string; github_url?: string; + avatar_url?: string; } export interface Collection { @@ -48,11 +49,11 @@ declare module "replicate" { export interface FileObject { id: string; - name: string; content_type: string; size: number; - etag: string; - checksum: string; + checksums: { + sha256: string; + }; metadata: Record; created_at: string; expires_at: string | null; @@ -85,22 +86,26 @@ declare module "replicate" { export interface ModelVersion { id: string; created_at: string; - cog_version: string; - openapi_schema: object; + cog_version: string | null; + openapi_schema: object | null; } export interface Prediction { id: string; status: Status; model: string; - version: string; + version: string | "hidden"; input: object; output?: any; // TODO: this should be `unknown` source: "api" | "web"; error?: unknown; logs?: string; + data_removed: boolean; + deadline?: string; + deployment?: string; metrics?: { predict_time?: number; + total_time?: number; }; webhook?: string; webhook_events_filter?: WebhookEventType[]; @@ -111,10 +116,38 @@ declare module "replicate" { get: string; cancel: string; stream?: string; + web?: string; }; } - export type Training = Prediction; + export interface Training { + id: string; + status: Status; + model: string; + version: string; + input: object; + output?: { + version?: string; + weights?: string; + }; + source: "api" | "web"; + error?: unknown; + logs?: string; + metrics?: { + predict_time?: number; + total_time?: number; + }; + webhook?: string; + webhook_events_filter?: WebhookEventType[]; + created_at: string; + started_at?: string; + completed_at?: string; + urls: { + get: string; + cancel: string; + web?: string; + }; + } export type FileEncodingStrategy = "default" | "upload" | "data-uri"; diff --git a/index.test.ts b/index.test.ts index 4905908..df277bf 100644 --- a/index.test.ts +++ b/index.test.ts @@ -4,6 +4,7 @@ import Replicate, { FileOutput, Model, Prediction, + Training, validateWebhook, parseProgressFromLogs, } from "replicate"; @@ -906,7 +907,7 @@ describe("Replicate client", () => { next: null, }); - const results: Prediction[] = []; + const results: Training[] = []; for await (const batch of client.paginate(client.trainings.list)) { results.push(...batch); } @@ -1176,11 +1177,11 @@ describe("Replicate client", () => { .post("/files") .reply(200, { id: "123", - name: "test-file", content_type: "application/octet-stream", size: 1024, - etag: "abc123", - checksum: "sha256:1234567890abcdef", + checksums: { + sha256: "1234567890abcdef", + }, metadata: {}, created_at: "2023-01-01T00:00:00Z", expires_at: null, @@ -1190,7 +1191,6 @@ describe("Replicate client", () => { }); const file = await client.files.create(testCase.value); expect(file.id).toBe("123"); - expect(file.name).toBe("test-file"); } }); }); @@ -1201,11 +1201,11 @@ describe("Replicate client", () => { .get("/files/123") .reply(200, { id: "123", - name: "test-file", content_type: "application/octet-stream", size: 1024, - etag: "abc123", - checksum: "sha256:1234567890abcdef", + checksums: { + sha256: "1234567890abcdef", + }, metadata: {}, created_at: "2023-01-01T00:00:00Z", expires_at: null, @@ -1216,7 +1216,6 @@ describe("Replicate client", () => { const file = await client.files.get("123"); expect(file.id).toBe("123"); - expect(file.name).toBe("test-file"); }); }); @@ -1230,11 +1229,11 @@ describe("Replicate client", () => { results: [ { id: "123", - name: "test-file", content_type: "application/octet-stream", size: 1024, - etag: "abc123", - checksum: "sha256:1234567890abcdef", + checksums: { + sha256: "1234567890abcdef", + }, metadata: {}, created_at: "2023-01-01T00:00:00Z", expires_at: null,