@@ -3,6 +3,8 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";
33import type { BaseArgs , Options } from "../../types" ;
44import { omit } from "../../utils/omit" ;
55import { request } from "../custom/request" ;
6+ import { delay } from "../../utils/delay" ;
7+ import { randomUUID } from "crypto" ;
68
79export type TextToImageArgs = BaseArgs & TextToImageInput ;
810
@@ -14,26 +16,33 @@ interface Base64ImageGeneration {
1416interface OutputUrlImageGeneration {
1517 output : string [ ] ;
1618}
19+ interface BlackForestLabsResponse {
20+ id : string ;
21+ polling_url : string ;
22+ }
1723
1824/**
1925 * This task reads some text input and outputs an image.
2026 * Recommended model: stabilityai/stable-diffusion-2
2127 */
2228export async function textToImage ( args : TextToImageArgs , options ?: Options ) : Promise < Blob > {
2329 const payload =
24- args . provider === "together" || args . provider === "fal-ai" || args . provider === "replicate"
30+ args . provider === "together" || args . provider === "fal-ai" || args . provider === "replicate" || args . provider === "black-forest-labs"
2531 ? {
26- ...omit ( args , [ "inputs" , "parameters" ] ) ,
27- ...args . parameters ,
28- ...( args . provider !== "replicate" ? { response_format : "base64" } : undefined ) ,
29- prompt : args . inputs ,
30- }
32+ ...omit ( args , [ "inputs" , "parameters" ] ) ,
33+ ...args . parameters ,
34+ ...( args . provider !== "replicate" ? { response_format : "base64" } : undefined ) ,
35+ prompt : args . inputs ,
36+ }
3137 : args ;
32- const res = await request < TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration > ( payload , {
38+ const res = await request < TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration | BlackForestLabsResponse > ( payload , {
3339 ...options ,
3440 taskHint : "text-to-image" ,
3541 } ) ;
3642 if ( res && typeof res === "object" ) {
43+ if ( args . provider === "black-forest-labs" && "polling_url" in res && typeof res . polling_url === "string" ) {
44+ return await pollBflResponse ( res . polling_url ) ;
45+ }
3746 if ( args . provider === "fal-ai" && "images" in res && Array . isArray ( res . images ) && res . images [ 0 ] . url ) {
3847 const image = await fetch ( res . images [ 0 ] . url ) ;
3948 return await image . blob ( ) ;
@@ -56,3 +65,23 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
5665 }
5766 return res ;
5867}
68+
69+ async function pollBflResponse ( url : string ) : Promise < Blob > {
70+ const urlObj = new URL ( url ) ;
71+ for ( let step = 0 ; step < 5 ; step ++ ) {
72+ await delay ( 1000 ) ;
73+ console . debug ( `Polling Black Forest Labs API for the result... ${ step + 1 } /5` ) ;
74+ urlObj . searchParams . set ( "uuid" , randomUUID ( ) ) ;
75+ const resp = await fetch ( urlObj , { headers : { "Content-Type" : "application/json" } } ) ;
76+ if ( ! resp . ok ) {
77+ throw new InferenceOutputError ( "Failed to fetch result from black forest labs API" ) ;
78+ }
79+ const payload = await resp . json ( ) ;
80+ if ( typeof payload === "object" && payload && "status" in payload && typeof payload . status === "string" && payload . status === "Ready" && "result" in payload && typeof payload . result === "object" && payload . result && "sample" in payload . result && typeof payload . result . sample === "string" ) {
81+ const image = await fetch ( payload . result . sample ) ;
82+ return await image . blob ( ) ;
83+ }
84+ }
85+ throw new InferenceOutputError ( "Failed to fetch result from black forest labs API" ) ;
86+ }
87+
0 commit comments