|
1 | 1 | import asyncio |
2 | 2 | import datetime |
3 | 3 | import os |
| 4 | +from collections.abc import Sequence |
4 | 5 | from enum import Enum |
5 | 6 |
|
6 | 7 | import reflex as rx |
|
11 | 12 |
|
12 | 13 | DEFAULT_IMAGE = "/default.webp" |
13 | 14 | API_TOKEN_ENV_VAR = "REPLICATE_API_TOKEN" |
| 15 | +MODEL_ENV_VAR = "REPLICATE_MODEL" |
| 16 | +UPSCALE_MODEL_ENV_VAR = "REPLICATE_UPSCALE_MODEL" |
| 17 | + |
| 18 | +DEFAULT_MODEL = "google/imagen-4-fast" |
| 19 | +# philz1337x /clarity-upscaler:029d48aa |
| 20 | +DEFAULT_UPSCALE_MODEL = ( |
| 21 | + "029d48aa21712d6769d7a46729c1edf0e4d41919c70b270785f10abb82989ba5" |
| 22 | +) |
14 | 23 |
|
15 | 24 | CopyLocalState = rx._x.client_state(default=False, var_name="copying") |
16 | 25 |
|
@@ -67,7 +76,7 @@ async def generate_image(self): |
67 | 76 |
|
68 | 77 | # Await the output from the replicate API |
69 | 78 | response = await replicate.predictions.async_create( |
70 | | - "5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f", |
| 79 | + os.environ.get(MODEL_ENV_VAR, DEFAULT_MODEL), |
71 | 80 | input=input, |
72 | 81 | ) |
73 | 82 |
|
@@ -99,8 +108,12 @@ async def generate_image(self): |
99 | 108 | await asyncio.sleep(0.15) |
100 | 109 | async with self: |
101 | 110 | self.upscaled_image = "" |
102 | | - self.output_image = response.output[0] |
103 | | - self.output_list = [] if len(response.output) == 1 else response.output |
| 111 | + if isinstance(response.output, str): |
| 112 | + self.output_image = response.output |
| 113 | + self.output_list = [] |
| 114 | + elif isinstance(response.output, Sequence): |
| 115 | + self.output_image = response.output[0] |
| 116 | + self.output_list = list(response.output[1:]) |
104 | 117 | self._reset_state() |
105 | 118 |
|
106 | 119 | except Exception as e: |
@@ -152,7 +165,7 @@ async def upscale_image(self): |
152 | 165 |
|
153 | 166 | # Await the output from the replicate API |
154 | 167 | response = await replicate.predictions.async_create( |
155 | | - "029d48aa21712d6769d7a46729c1edf0e4d41919c70b270785f10abb82989ba5", |
| 168 | + os.environ.get(UPSCALE_MODEL_ENV_VAR, DEFAULT_UPSCALE_MODEL), |
156 | 169 | input=input, |
157 | 170 | ) |
158 | 171 |
|
|
0 commit comments