
Commit c886dd5

Implement provider logic

1 parent 7be9b60 commit c886dd5

8 files changed, +175 -25 lines changed

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 26 additions & 15 deletions
@@ -4,7 +4,7 @@ import { REPLICATE_API_BASE_URL } from "../providers/replicate";
 import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
 import { TOGETHER_API_BASE_URL } from "../providers/together";
 import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
-import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/blackforestlabs-ai";
+import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
 import type { InferenceProvider } from "../types";
 import type { InferenceTask, Options, RequestArgs } from "../types";
 import { isUrl } from "./isUrl";
@@ -68,26 +68,31 @@ export async function makeRequestOptions(
 			? "hf-token"
 			: "provider-key"
 		: includeCredentials === "include"
-		? "credentials-include"
-		: "none";
+		  ? "credentials-include"
+		  : "none";
 
 	const url = endpointUrl
 		? chatCompletion
 			? endpointUrl + `/v1/chat/completions`
 			: endpointUrl
 		: makeUrl({
-			authMethod,
-			chatCompletion: chatCompletion ?? false,
-			forceTask,
-			model,
-			provider: provider ?? "hf-inference",
-			taskHint,
-		});
+				authMethod,
+				chatCompletion: chatCompletion ?? false,
+				forceTask,
+				model,
+				provider: provider ?? "hf-inference",
+				taskHint,
+		  });
 
 	const headers: Record<string, string> = {};
 	if (accessToken) {
-		headers["Authorization"] =
-			provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
+		if (provider === "fal-ai" && authMethod === "provider-key") {
+			headers["Authorization"] = `Key ${accessToken}`;
+		} else if (provider === "black-forest-labs" && authMethod === "provider-key") {
+			headers["X-Key"] = accessToken;
+		} else {
+			headers["Authorization"] = `Bearer ${accessToken}`;
+		}
 	}
 
 	// e.g. @huggingface/inference/3.1.3
@@ -143,9 +148,9 @@ export async function makeRequestOptions(
 		body: binary
 			? args.data
 			: JSON.stringify({
-				...otherArgs,
-				...(chatCompletion || provider === "together" ? { model } : undefined),
-			}),
+					...otherArgs,
+					...(chatCompletion || provider === "together" ? { model } : undefined),
+			  }),
 		...(credentials ? { credentials } : undefined),
 		signal: options?.signal,
 	};
@@ -167,6 +172,12 @@ function makeUrl(params: {
 
 	const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
 	switch (params.provider) {
+		case "black-forest-labs": {
+			const baseUrl = shouldProxy
+				? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
+				: BLACKFORESTLABS_AI_API_BASE_URL;
+			return `${baseUrl}/${params.model}`;
+		}
 		case "fal-ai": {
 			const baseUrl = shouldProxy
 				? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
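
Note: with the two changes above, a Black Forest Labs call authenticated with a provider key goes straight to the BFL API and uses an X-Key header instead of Authorization. A minimal sketch of the resulting request shape, assuming the flux-dev model mapping that the test suite below registers (illustrative only, not code from this commit):

// Hypothetical illustration for provider === "black-forest-labs" with authMethod === "provider-key".
const url = "https://api.us1.bfl.ai/v1/flux-dev"; // `${BLACKFORESTLABS_AI_API_BASE_URL}/${model}`
const headers = {
	"X-Key": "<your BFL API key>", // replaces the usual `Authorization: Bearer ...`
	"Content-Type": "application/json",
};
// With an hf_ token instead, shouldProxy is true: the URL comes from
// HF_HUB_INFERENCE_PROXY_TEMPLATE (with {{PROVIDER}} = "black-forest-labs")
// and the header falls back to `Authorization: Bearer hf_...`.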

packages/inference/src/providers/blackforestlabs-ai.ts renamed to packages/inference/src/providers/black-forest-labs.ts

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1/";
+export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
 
 /**
  * See the registered mapping of HF model ID => Black Forest Labs model ID here:

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
 	 * Example:
 	 * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
 	 */
+	"black-forest-labs": {},
 	"fal-ai": {},
 	"fireworks-ai": {},
 	"hf-inference": {},

packages/inference/src/tasks/cv/textToImage.ts

Lines changed: 36 additions & 7 deletions
@@ -3,6 +3,8 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, Options } from "../../types";
 import { omit } from "../../utils/omit";
 import { request } from "../custom/request";
+import { delay } from "../../utils/delay";
+import { randomUUID } from "crypto";
 
 export type TextToImageArgs = BaseArgs & TextToImageInput;
 
@@ -14,26 +16,33 @@ interface Base64ImageGeneration {
 interface OutputUrlImageGeneration {
 	output: string[];
 }
+interface BlackForestLabsResponse {
+	id: string;
+	polling_url: string;
+}
 
 /**
  * This task reads some text input and outputs an image.
  * Recommended model: stabilityai/stable-diffusion-2
  */
 export async function textToImage(args: TextToImageArgs, options?: Options): Promise<Blob> {
 	const payload =
-		args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate"
+		args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate" || args.provider === "black-forest-labs"
 			? {
-				...omit(args, ["inputs", "parameters"]),
-				...args.parameters,
-				...(args.provider !== "replicate" ? { response_format: "base64" } : undefined),
-				prompt: args.inputs,
-			}
+					...omit(args, ["inputs", "parameters"]),
+					...args.parameters,
+					...(args.provider !== "replicate" ? { response_format: "base64" } : undefined),
+					prompt: args.inputs,
+			  }
 			: args;
-	const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(payload, {
+	const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration | BlackForestLabsResponse>(payload, {
 		...options,
 		taskHint: "text-to-image",
 	});
 	if (res && typeof res === "object") {
+		if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
+			return await pollBflResponse(res.polling_url);
+		}
 		if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
 			const image = await fetch(res.images[0].url);
 			return await image.blob();
@@ -56,3 +65,23 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Promise<Blob> {
 	}
 	return res;
 }
+
+async function pollBflResponse(url: string): Promise<Blob> {
+	const urlObj = new URL(url);
+	for (let step = 0; step < 5; step++) {
+		await delay(1000);
+		console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
+		urlObj.searchParams.set("uuid", randomUUID());
+		const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
+		if (!resp.ok) {
+			throw new InferenceOutputError("Failed to fetch result from black forest labs API");
+		}
+		const payload = await resp.json();
+		if (typeof payload === "object" && payload && "status" in payload && typeof payload.status === "string" && payload.status === "Ready" && "result" in payload && typeof payload.result === "object" && payload.result && "sample" in payload.result && typeof payload.result.sample === "string") {
+			const image = await fetch(payload.result.sample);
+			return await image.blob();
+		}
+	}
+	throw new InferenceOutputError("Failed to fetch result from black forest labs API");
+}
+
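
Note: pollBflResponse polls the returned polling_url up to five times, one second apart, and only treats the job as done when status is "Ready" and result.sample is a string URL, which it then fetches as the image Blob. The payload shape it expects, inferred from those guards and from the recorded tapes further down (not an interface defined in this commit), is roughly:

interface BflGetResultResponse {
	id: string;
	status: string; // e.g. "Pending" while generating; only "Ready" is accepted
	result: { sample: string } | null; // `sample` is a signed URL to the generated image
	progress: number | null;
}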

packages/inference/src/utils/delay.ts

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+export function delay(ms: number): Promise<void> {
+	return new Promise(resolve => {
+		setTimeout(() => resolve(), ms);
+	});
+}

packages/inference/test/HfInference.spec.ts

Lines changed: 30 additions & 1 deletion
@@ -2,7 +2,7 @@ import { assert, describe, expect, it } from "vitest";
 
 import type { ChatCompletionStreamOutput } from "@huggingface/tasks";
 
-import { chatCompletion, HfInference } from "../src";
+import { chatCompletion, HfInference, textToImage } from "../src";
 import { textToVideo } from "../src/tasks/cv/textToVideo";
 import { readTestFile } from "./test-files";
 import "./vcr";
@@ -1125,4 +1125,33 @@ describe.concurrent("HfInference", () => {
 		},
 		TIMEOUT
 	);
+
+	describe.concurrent(
+		"Black Forest Labs",
+		() => {
+
+			HARDCODED_MODEL_ID_MAPPING["black-forest-labs"] = {
+				"black-forest-labs/FLUX.1-dev": "flux-dev",
+				// "black-forest-labs/FLUX.1-schnell": "flux-pro",
+			};
+
+			it("textToImage", async () => {
+				const res = await textToImage({
+					model: "black-forest-labs/FLUX.1-dev",
+					provider: "black-forest-labs",
+					accessToken: env.HF_BLACK_FOREST_LABS_KEY,
+					inputs: "A raccoon driving a truck",
+					parameters: {
+						height: 256,
+						width: 256,
+						num_inference_steps: 4,
+						seed: 8817,
+					}
+				});
+				expect(res).toBeInstanceOf(Blob);
+			});
+
+		},
+		TIMEOUT
+	);
 });

packages/inference/test/tapes.json

Lines changed: 75 additions & 0 deletions
@@ -6920,5 +6920,80 @@
         "vary": "Accept-Encoding"
       }
     }
+  },
+  "31cf35b760a4e1567bcaa8094b425854e35e1285aa014263253a136875197665": {
+    "url": "https://api.us1.bfl.ai/v1/flux-dev",
+    "init": {
+      "headers": {
+        "Content-Type": "application/json"
+      },
+      "method": "POST",
+      "body": "{\"height\":256,\"width\":256,\"num_inference_steps\":4,\"seed\":8817,\"response_format\":\"base64\",\"prompt\":\"A raccoon driving a truck\"}"
+    },
+    "response": {
+      "body": "{\"id\":\"a00fbfe4-1abd-4d54-966a-71cb1c3c3700\",\"polling_url\":\"https://api.us1.bfl.ai/v1/get_result?id=a00fbfe4-1abd-4d54-966a-71cb1c3c3700\"}",
+      "status": 200,
+      "statusText": "OK",
+      "headers": {
+        "connection": "keep-alive",
+        "content-type": "application/json",
+        "strict-transport-security": "max-age=31536000; includeSubDomains"
+      }
+    }
+  },
+  "bc182ea4d907a4b1fe558f2b238ac5252f7b9b4768c9a398faae8e2b89177d49": {
+    "url": "https://api.us1.bfl.ai/v1/get_result?id=a00fbfe4-1abd-4d54-966a-71cb1c3c3700&uuid=87c533f5-b483-4105-8bc8-4aacd2831cc0",
+    "init": {
+      "headers": {
+        "Content-Type": "application/json"
+      }
+    },
+    "response": {
+      "body": "{\"id\":\"a00fbfe4-1abd-4d54-966a-71cb1c3c3700\",\"status\":\"Pending\",\"result\":null,\"progress\":0.7}",
+      "status": 200,
+      "statusText": "OK",
+      "headers": {
+        "connection": "keep-alive",
+        "content-type": "application/json",
+        "retry-after": "1",
+        "strict-transport-security": "max-age=31536000; includeSubDomains"
+      }
+    }
+  },
+  "e44536068e0d85fea147002c94e3a7c62b82dd8c8314f69a4ae777b49056b657": {
+    "url": "https://api.us1.bfl.ai/v1/get_result?id=a00fbfe4-1abd-4d54-966a-71cb1c3c3700&uuid=b0e31b03-3229-4ed5-a90f-066b718bded0",
+    "init": {
+      "headers": {
+        "Content-Type": "application/json"
+      }
+    },
+    "response": {
+      "body": "{\"id\":\"a00fbfe4-1abd-4d54-966a-71cb1c3c3700\",\"status\":\"Ready\",\"result\":{\"sample\":\"https://delivery-us1.bfl.ai/results/ff9ee696c18a4faeb09835b997ed711b/sample.jpeg?se=2025-02-12T14%3A32%3A39Z&sp=r&sv=2024-11-04&sr=b&rsct=image/jpeg&sig=zJ%2BxSXMHsDd9Jn/r0FONvWSzKXYtrFk0SLLf8zrhXO0%3D\",\"prompt\":\"A raccoon driving a truck\",\"seed\":8817,\"start_time\":1739370157.0408669,\"end_time\":1739370159.3932223,\"duration\":2.352355480194092},\"progress\":null}",
+      "status": 200,
+      "statusText": "OK",
+      "headers": {
+        "connection": "keep-alive",
+        "content-type": "application/json",
+        "retry-after": "1",
+        "strict-transport-security": "max-age=31536000; includeSubDomains"
+      }
+    }
+  },
+  "2bec44b6259bd25573ac1008a7d017f30bc541fc55de3906e882a46ef7f39e0e": {
+    "url": "https://delivery-us1.bfl.ai/results/ff9ee696c18a4faeb09835b997ed711b/sample.jpeg?se=2025-02-12T14%3A32%3A39Z&sp=r&sv=2024-11-04&sr=b&rsct=image/jpeg&sig=zJ%2BxSXMHsDd9Jn/r0FONvWSzKXYtrFk0SLLf8zrhXO0%3D",
+    "init": {},
+    "response": {
+      "body": "",
+      "status": 200,
+      "statusText": "OK",
+      "headers": {
+        "accept-ranges": "bytes",
+        "connection": "keep-alive",
+        "content-md5": "NpZjwRHVW54BWXabyXtu+A==",
+        "content-type": "image/jpeg",
+        "etag": "\"0x8DD4B70B47817DA\"",
+        "last-modified": "Wed, 12 Feb 2025 14:22:39 GMT"
+      }
+    }
   }
 }

packages/inference/test/vcr.ts

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ async function vcr(
 	const tape: Tape = {
 		url,
 		init: {
-			headers: init.headers && omit(init.headers as Record<string, string>, ["Authorization", "User-Agent"]),
+			headers: init.headers && omit(init.headers as Record<string, string>, ["Authorization", "User-Agent", "X-Key"]),
 			method: init.method,
 			body: typeof init.body === "string" && init.body.length < 1_000 ? init.body : undefined,
 		},

0 commit comments
