Skip to content

Commit 07a86bf

Browse files
committed
Move app scripts into chart repo
1 parent 6e61451 commit 07a86bf

File tree

11 files changed

+403
-2
lines changed

11 files changed

+403
-2
lines changed

charts/flux-image-gen/templates/tests/gradio-api.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ spec:
1212
image: "{{ $.Values.image.repository }}:{{ $.Values.image.tag | default $.Chart.AppVersion }}"
1313
command:
1414
- python
15-
- stackhpc-app/test_client.py
15+
- test_client.py
1616
env:
1717
- name: GRADIO_HOST
1818
value: {{ printf "http://%s-ui.%s.svc:%v" (include "flux-image-gen.fullname" .) .Release.Namespace .Values.ui.service.port }}

charts/flux-image-gen/templates/ui/deployment.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ spec:
3939
imagePullPolicy: {{ .Values.image.pullPolicy }}
4040
command:
4141
- python
42-
- stackhpc-app/gradio_ui.py
42+
- gradio_ui.py
4343
ports:
4444
- name: http
4545
containerPort: {{ .Values.ui.service.port }}

web-apps/flux-image-gen/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
output/

web-apps/flux-image-gen/Dockerfile

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
FROM python:3.11
2+
3+
# https://stackoverflow.com/questions/55313610/importerror-libgl-so-1-cannot-open-shared-object-file-no-such-file-or-directo
4+
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
5+
6+
7+
ARG DIR=flux-image-gen
8+
9+
COPY $DIR/requirements.txt requirements.txt
10+
RUN pip install --no-cache-dir -r requirements.txt
11+
12+
COPY purge-google-fonts.sh .
13+
RUN bash purge-google-fonts.sh
14+
15+
WORKDIR /app
16+
17+
COPY $DIR/*.py .
18+
19+
COPY $DIR/gradio_config.yaml .
20+
21+
COPY $DIR/test-image.jpg .
22+
23+
ENTRYPOINT ["fastapi", "run", "api_server.py"]
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import io
2+
import os
3+
import sys
4+
import torch
5+
6+
from fastapi import FastAPI
7+
from fastapi.responses import Response, JSONResponse
8+
from PIL import Image
9+
from pydantic import BaseModel
10+
11+
from image_gen import FluxGenerator
12+
13+
# Detect if app is run using `fastapi dev ...`
14+
DEV_MODE = sys.argv[1] == "dev"
15+
16+
app = FastAPI()
17+
18+
device = "cuda" if torch.cuda.is_available() else "cpu"
19+
model = os.environ.get("FLUX_MODEL_NAME", "flux-schnell")
20+
if not DEV_MODE:
21+
print("Loading model", model)
22+
generator = FluxGenerator(model, device, offload=False)
23+
24+
25+
class ImageGenInput(BaseModel):
26+
width: int
27+
height: int
28+
num_steps: int
29+
guidance: float
30+
seed: int
31+
prompt: str
32+
add_sampling_metadata: bool
33+
34+
35+
@app.get("/model")
36+
async def get_model():
37+
return {"model": model}
38+
39+
40+
@app.post("/generate")
41+
async def generate_image(input: ImageGenInput):
42+
if DEV_MODE:
43+
# For quicker testing or when GPU hardware not available
44+
fn = "test-image.jpg"
45+
seed = "dev"
46+
image = Image.open(fn)
47+
# Uncomment to test error handling
48+
# return JSONResponse({"error": {"message": "Dev mode error test", "seed": "not-so-random"}}, status_code=400)
49+
else:
50+
# Main image generation functionality
51+
image, seed, msg = generator.generate_image(
52+
input.width,
53+
input.height,
54+
input.num_steps,
55+
input.guidance,
56+
input.seed,
57+
input.prompt,
58+
add_sampling_metadata=input.add_sampling_metadata,
59+
)
60+
if not image:
61+
return JSONResponse({"error": {"message": msg, "seed": seed}}, status_code=400)
62+
# Convert image to bytes response
63+
buffer = io.BytesIO()
64+
image.save(buffer, format="jpeg")
65+
bytes = buffer.getvalue()
66+
return Response(bytes, media_type="image/jpeg", headers={"x-flux-seed": seed})
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
models:
2+
- name: flux-schnell
3+
address: http://localhost:8000
4+
example_prompt: |
5+
Yoda riding a skateboard.
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import io
2+
import os
3+
import httpx
4+
import uuid
5+
import pathlib
6+
import yaml
7+
8+
import gradio as gr
9+
from pydantic import BaseModel, HttpUrl
10+
from PIL import Image, ExifTags
11+
from typing import List
12+
from urllib.parse import urljoin
13+
14+
15+
class Model(BaseModel):
16+
name: str
17+
address: HttpUrl
18+
19+
class AppSettings(BaseModel):
20+
models: List[Model]
21+
example_prompt: str
22+
23+
24+
settings_path = pathlib.Path("/etc/gradio-app/gradio_config.yaml")
25+
if not settings_path.exists():
26+
print("No settings overrides found at", settings_path)
27+
settings_path = "./gradio_config.yaml"
28+
print("Using settings from", settings_path)
29+
with open(settings_path, "r") as file:
30+
settings = AppSettings(**yaml.safe_load(file))
31+
print("App config:", settings.model_dump())
32+
33+
MODELS = {m.name: m.address for m in settings.models}
34+
MODEL_NAMES = list(MODELS.keys())
35+
36+
# Disable analytics for GDPR compliance
37+
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
38+
39+
def save_image(model_name: str, prompt: str, seed: int, add_sampling_metadata: bool, image: Image.Image):
40+
filename = f"output/gradio/{uuid.uuid4()}.jpg"
41+
os.makedirs(os.path.dirname(filename), exist_ok=True)
42+
exif_data = Image.Exif()
43+
exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
44+
exif_data[ExifTags.Base.Make] = "Black Forest Labs"
45+
exif_data[ExifTags.Base.Model] = model_name
46+
if add_sampling_metadata:
47+
exif_data[ExifTags.Base.ImageDescription] = prompt
48+
image.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
49+
return filename
50+
51+
52+
async def generate_image(
53+
model_name: str,
54+
width: int,
55+
height: int,
56+
num_steps: int,
57+
guidance: float,
58+
seed: int,
59+
prompt: str,
60+
add_sampling_metadata: bool,
61+
):
62+
url = urljoin(str(MODELS[model_name]), "/generate")
63+
data = {
64+
"width": width,
65+
"height": height,
66+
"num_steps": num_steps,
67+
"guidance": guidance,
68+
"seed": seed,
69+
"prompt": prompt,
70+
"add_sampling_metadata": add_sampling_metadata,
71+
}
72+
async with httpx.AsyncClient(timeout=60) as client:
73+
try:
74+
response = await client.post(url, json=data)
75+
except httpx.ConnectError:
76+
raise gr.Error("Model backend unavailable")
77+
if response.status_code == 400:
78+
data = response.json()
79+
if "error" in data and "message" in data["error"]:
80+
message = data["error"]["message"]
81+
if "seed" in data["error"]:
82+
message += f" (seed: {data['error']['seed']})"
83+
raise gr.Error(message)
84+
try:
85+
response.raise_for_status()
86+
except httpx.HTTPStatusError as err:
87+
# Raise a generic error message to avoid leaking unwanted details
88+
# Admin should consult API logs for more info
89+
raise gr.Error(f"Backend error (HTTP {err.response.status_code})")
90+
image = Image.open(io.BytesIO(response.content))
91+
seed = response.headers.get("x-flux-seed", "unknown")
92+
filename = save_image(model_name, prompt, seed, add_sampling_metadata, image)
93+
94+
return image, seed, filename, None
95+
96+
97+
with gr.Blocks() as demo:
98+
gr.Markdown("# Flux Image Generation Demo")
99+
100+
with gr.Row():
101+
with gr.Column():
102+
model = gr.Dropdown(MODEL_NAMES, value=MODEL_NAMES[0], label="Model", interactive=len(MODEL_NAMES) > 1)
103+
prompt = gr.Textbox(label="Prompt", value=settings.example_prompt)
104+
105+
with gr.Accordion("Advanced Options", open=False):
106+
# TODO: Make min/max slide values configurable
107+
width = gr.Slider(128, 8192, 1360, step=16, label="Width")
108+
height = gr.Slider(128, 8192, 768, step=16, label="Height")
109+
num_steps = gr.Slider(1, 50, 4 if model.value == "flux-schnell" else 50, step=1, label="Number of steps")
110+
guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not model.value == "flux-schnell")
111+
seed = gr.Textbox("-1", label="Seed (-1 for random)")
112+
add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=True)
113+
114+
generate_btn = gr.Button("Generate")
115+
116+
with gr.Column():
117+
output_image = gr.Image(label="Generated Image")
118+
seed_output = gr.Textbox(label="Used Seed")
119+
warning_text = gr.Textbox(label="Warning", visible=False)
120+
download_btn = gr.File(label="Download full-resolution")
121+
122+
generate_btn.click(
123+
fn=generate_image,
124+
inputs=[model, width, height, num_steps, guidance, seed, prompt, add_sampling_metadata],
125+
outputs=[output_image, seed_output, download_btn, warning_text],
126+
)
127+
demo.launch(enable_monitoring=False)

0 commit comments

Comments
 (0)