Skip to content

Commit 5c7e848

Browse files
committed
Set image model and upscale model via env var
1 parent 7f7b3c0 commit 5c7e848

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

ai_image_gen/ai_image_gen/backend/generation.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import datetime
33
import os
4+
from collections.abc import Sequence
45
from enum import Enum
56

67
import reflex as rx
@@ -11,6 +12,14 @@
1112

1213
DEFAULT_IMAGE = "/default.webp"
1314
API_TOKEN_ENV_VAR = "REPLICATE_API_TOKEN"
15+
MODEL_ENV_VAR = "REPLICATE_MODEL"
16+
UPSCALE_MODEL_ENV_VAR = "REPLICATE_UPSCALE_MODEL"
17+
18+
DEFAULT_MODEL = "google/imagen-4-fast"
19+
# philz1337x /clarity-upscaler:029d48aa
20+
DEFAULT_UPSCALE_MODEL = (
21+
"029d48aa21712d6769d7a46729c1edf0e4d41919c70b270785f10abb82989ba5"
22+
)
1423

1524
CopyLocalState = rx._x.client_state(default=False, var_name="copying")
1625

@@ -67,7 +76,7 @@ async def generate_image(self):
6776

6877
# Await the output from the replicate API
6978
response = await replicate.predictions.async_create(
70-
"5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
79+
os.environ.get(MODEL_ENV_VAR, DEFAULT_MODEL),
7180
input=input,
7281
)
7382

@@ -99,8 +108,12 @@ async def generate_image(self):
99108
await asyncio.sleep(0.15)
100109
async with self:
101110
self.upscaled_image = ""
102-
self.output_image = response.output[0]
103-
self.output_list = [] if len(response.output) == 1 else response.output
111+
if isinstance(response.output, str):
112+
self.output_image = response.output
113+
self.output_list = []
114+
elif isinstance(response.output, Sequence):
115+
self.output_image = response.output[0]
116+
self.output_list = list(response.output[1:])
104117
self._reset_state()
105118

106119
except Exception as e:
@@ -152,7 +165,7 @@ async def upscale_image(self):
152165

153166
# Await the output from the replicate API
154167
response = await replicate.predictions.async_create(
155-
"029d48aa21712d6769d7a46729c1edf0e4d41919c70b270785f10abb82989ba5",
168+
os.environ.get(UPSCALE_MODEL_ENV_VAR, DEFAULT_UPSCALE_MODEL),
156169
input=input,
157170
)
158171

0 commit comments

Comments
 (0)