Skip to content

Commit 9c81f61

Browse files
committed
fast latent image preview
1 parent 4570715 commit 9c81f61

File tree

3 files changed

+138
-7
lines changed

3 files changed

+138
-7
lines changed

examples/cli/main.cpp

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,125 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
765765
fflush(out_stream);
766766
}
767767

768+
// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L152-L169
769+
const float flux_latent_rgb_proj[16][3] = {
770+
{-0.0346, 0.0244, 0.0681},
771+
{0.0034, 0.0210, 0.0687},
772+
{0.0275, -0.0668, -0.0433},
773+
{-0.0174, 0.0160, 0.0617},
774+
{0.0859, 0.0721, 0.0329},
775+
{0.0004, 0.0383, 0.0115},
776+
{0.0405, 0.0861, 0.0915},
777+
{-0.0236, -0.0185, -0.0259},
778+
{-0.0245, 0.0250, 0.1180},
779+
{0.1008, 0.0755, -0.0421},
780+
{-0.0515, 0.0201, 0.0011},
781+
{0.0428, -0.0012, -0.0036},
782+
{0.0817, 0.0765, 0.0749},
783+
{-0.1264, -0.0522, -0.1103},
784+
{-0.0280, -0.0881, -0.0499},
785+
{-0.1262, -0.0982, -0.0778}};
786+
787+
// https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py#L228-L246
788+
const float sd3_latent_rgb_proj[16][3] = {
789+
{-0.0645, 0.0177, 0.1052},
790+
{0.0028, 0.0312, 0.0650},
791+
{0.1848, 0.0762, 0.0360},
792+
{0.0944, 0.0360, 0.0889},
793+
{0.0897, 0.0506, -0.0364},
794+
{-0.0020, 0.1203, 0.0284},
795+
{0.0855, 0.0118, 0.0283},
796+
{-0.0539, 0.0658, 0.1047},
797+
{-0.0057, 0.0116, 0.0700},
798+
{-0.0412, 0.0281, -0.0039},
799+
{0.1106, 0.1171, 0.1220},
800+
{-0.0248, 0.0682, -0.0481},
801+
{0.0815, 0.0846, 0.1207},
802+
{-0.0120, -0.0055, -0.0867},
803+
{-0.0749, -0.0634, -0.0456},
804+
{-0.1418, -0.1457, -0.1259},
805+
};
806+
807+
// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
808+
const float sdxl_latent_rgb_proj[4][3] = {
809+
{0.3651, 0.4232, 0.4341},
810+
{-0.2533, -0.0042, 0.1068},
811+
{0.1076, 0.1111, -0.0362},
812+
{-0.3165, -0.2492, -0.2188}};
813+
814+
// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
815+
const float sd_latent_rgb_proj[4][3]{
816+
{0.3512, 0.2297, 0.3227},
817+
{0.3250, 0.4974, 0.2350},
818+
{-0.2829, 0.1762, 0.2721},
819+
{-0.2120, -0.2616, -0.7177}};
820+
821+
void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version) {
822+
const int channel = 3;
823+
int width = latents->ne[0];
824+
int height = latents->ne[1];
825+
int dim = latents->ne[2];
826+
827+
const float(*latent_rgb_proj)[channel];
828+
829+
if (dim == 16) {
830+
// 16 channels VAE -> Flux or SD3
831+
832+
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B /* || version == VERSION_SD3_5_2B*/) {
833+
latent_rgb_proj = sd3_latent_rgb_proj;
834+
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
835+
latent_rgb_proj = flux_latent_rgb_proj;
836+
} else {
837+
// unknown model
838+
return;
839+
}
840+
841+
} else if (dim == 4) {
842+
// 4 channels VAE
843+
if (version == VERSION_SDXL) {
844+
latent_rgb_proj = sdxl_latent_rgb_proj;
845+
} else if (version == VERSION_SD1 || version == VERSION_SD2) {
846+
latent_rgb_proj = sd_latent_rgb_proj;
847+
} else {
848+
// unknown model
849+
return;
850+
}
851+
} else {
852+
// unknown latent space
853+
return;
854+
}
855+
uint8_t* data = (uint8_t*)malloc(width * height * channel * sizeof(uint8_t));
856+
int data_head = 0;
857+
for (int j = 0; j < height; j++) {
858+
for (int i = 0; i < width; i++) {
859+
int latent_id = (i * latents->nb[0] + j * latents->nb[1]);
860+
float r = 0, g = 0, b = 0;
861+
for (int d = 0; d < dim; d++) {
862+
float value = *(float*)((char*)latents->data + latent_id + d * latents->nb[2]);
863+
r += value * latent_rgb_proj[d][0];
864+
g += value * latent_rgb_proj[d][1];
865+
b += value * latent_rgb_proj[d][2];
866+
}
867+
868+
// change range
869+
r = r * .5 + .5;
870+
g = g * .5 + .5;
871+
b = b * .5 + .5;
872+
873+
// clamp rgb values to [0,1] range
874+
r = r >= 0 ? r <= 1 ? r : 1 : 0;
875+
g = g >= 0 ? g <= 1 ? g : 1 : 0;
876+
b = b >= 0 ? b <= 1 ? b : 1 : 0;
877+
878+
data[data_head++] = (uint8_t)(r * 255.);
879+
data[data_head++] = (uint8_t)(g * 255.);
880+
data[data_head++] = (uint8_t)(b * 255.);
881+
}
882+
}
883+
stbi_write_png("latent-preview.png", width, height, channel, data, 0);
884+
free(data);
885+
}
886+
768887
int main(int argc, const char* argv[]) {
769888
SDParams params;
770889

@@ -930,7 +1049,8 @@ int main(int argc, const char* argv[]) {
9301049
params.skip_layers.size(),
9311050
params.slg_scale,
9321051
params.skip_layer_start,
933-
params.skip_layer_end);
1052+
params.skip_layer_end,
1053+
step_callback);
9341054
} else {
9351055
sd_image_t input_image = {(uint32_t)params.width,
9361056
(uint32_t)params.height,

stable-diffusion.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -782,7 +782,8 @@ class StableDiffusionGGML {
782782
std::vector<int> skip_layers = {},
783783
float slg_scale = 0,
784784
float skip_layer_start = 0.01,
785-
float skip_layer_end = 0.2) {
785+
float skip_layer_end = 0.2,
786+
std::function<void(int, ggml_tensor*, SDVersion)> step_callback = nullptr) {
786787
size_t steps = sigmas.size() - 1;
787788
// noise = load_tensor_from_file(work_ctx, "./rand0.bin");
788789
// print_ggml_tensor(noise);
@@ -941,6 +942,9 @@ class StableDiffusionGGML {
941942
pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
942943
// LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
943944
}
945+
if (step_callback != nullptr) {
946+
step_callback(step, denoised, version);
947+
}
944948
return denoised;
945949
};
946950

@@ -1164,7 +1168,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
11641168
std::vector<int> skip_layers = {},
11651169
float slg_scale = 0,
11661170
float skip_layer_start = 0.01,
1167-
float skip_layer_end = 0.2) {
1171+
float skip_layer_end = 0.2,
1172+
std::function<void(int, ggml_tensor*, SDVersion)> step_callback = nullptr) {
11681173
if (seed < 0) {
11691174
// Generally, when using the provided command line, the seed is always >0.
11701175
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -1386,7 +1391,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13861391
skip_layers,
13871392
slg_scale,
13881393
skip_layer_start,
1389-
skip_layer_end);
1394+
skip_layer_end,
1395+
step_callback);
13901396
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
13911397
// print_ggml_tensor(x_0);
13921398
int64_t sampling_end = ggml_time_ms();
@@ -1457,7 +1463,8 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
14571463
size_t skip_layers_count = 0,
14581464
float slg_scale = 0,
14591465
float skip_layer_start = 0.01,
1460-
float skip_layer_end = 0.2) {
1466+
float skip_layer_end = 0.2,
1467+
step_callback_t step_callback) {
14611468
std::vector<int> skip_layers_vec(skip_layers, skip_layers + skip_layers_count);
14621469
LOG_DEBUG("txt2img %dx%d", width, height);
14631470
if (sd_ctx == NULL) {
@@ -1530,7 +1537,8 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
15301537
skip_layers_vec,
15311538
slg_scale,
15321539
skip_layer_start,
1533-
skip_layer_end);
1540+
skip_layer_end,
1541+
step_callback);
15341542

15351543
size_t t1 = ggml_time_ms();
15361544

stable-diffusion.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
149149

150150
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
151151

152+
typedef void (*step_callback_t)(int, struct ggml_tensor*, enum SDVersion);
153+
152154
SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
153155
const char* prompt,
154156
const char* negative_prompt,
@@ -170,7 +172,8 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
170172
size_t skip_layers_count,
171173
float slg_scale,
172174
float skip_layer_start,
173-
float skip_layer_end);
175+
float skip_layer_end,
176+
step_callback_t step_callback = NULL);
174177

175178
SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
176179
sd_image_t init_image,

0 commit comments

Comments
 (0)