Skip to content

Commit eff2bda

Browse files
committed
experimental: preview during tiled vae decode
1 parent d18a216 commit eff2bda

File tree

2 files changed

+86
-34
lines changed

2 files changed

+86
-34
lines changed

ggml_extend.hpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,11 @@ __STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input) {
310310
for (int iy = 0; iy < height; iy++) {
311311
for (int ix = 0; ix < width; ix++) {
312312
for (int k = 0; k < channels; k++) {
313-
float value = ggml_tensor_get_f32(input, ix, iy, k);
313+
float value = ggml_tensor_get_f32(input, ix, iy, k);
314+
315+
value = value > 1.0f ? 1.0f : value < 0.0f ? 0.0f
316+
: value;
317+
314318
*(image_data + iy * width * channels + ix * channels + k) = (uint8_t)(value * 255.0f);
315319
}
316320
}
@@ -466,7 +470,8 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
466470
int overlap_x,
467471
int overlap_y,
468472
int x_skip = 0,
469-
int y_skip = 0) {
473+
int y_skip = 0,
474+
bool clear = false) {
470475
int64_t width = input->ne[0];
471476
int64_t height = input->ne[1];
472477
int64_t channels = input->ne[2];
@@ -486,6 +491,10 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
486491
const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1;
487492
const float y_f_0 = (overlap_y > 0 && y > 0) ? (iy - y_skip) / float(overlap_y) : 1;
488493
const float y_f_1 = (overlap_y > 0 && y < (img_height - height)) ? (height - iy) / float(overlap_y) : 1;
494+
// clear old value for first pass
495+
if (clear && (x_f_0 >= 1.0f || x == 0) && (y_f_0 >= 1.0f || y == 0)) {
496+
old_value = 0.0f;
497+
}
489498

490499
const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f);
491500
const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f);
@@ -597,9 +606,10 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
597606
}
598607

599608
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
609+
typedef std::function<void(ggml_tensor*)> on_tile_merge;
600610

601611
// Tiling
602-
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true) {
612+
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true, on_tile_merge on_merge = NULL) {
603613
int input_width = (int)input->ne[0];
604614
int input_height = (int)input->ne[1];
605615
int output_width = (int)output->ne[0];
@@ -713,13 +723,16 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
713723
ggml_split_tensor_2d(input, input_tile, x, y);
714724
on_processing(input_tile, output_tile, false);
715725
if (scaled_out) {
716-
ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap_x * scale, tile_overlap_y * scale, dx * scale, dy * scale);
726+
ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap_x * scale, tile_overlap_y * scale, dx * scale, dy * scale, true);
717727
} else {
718728
ggml_merge_tensor_2d(output_tile, output, x / scale, y / scale, tile_overlap_x / scale, tile_overlap_y / scale, dx / scale, dy / scale);
719729
}
720730
int64_t t2 = ggml_time_ms();
721731
last_time = (t2 - t1) / 1000.0f;
722732
pretty_progress(tile_count, num_tiles, last_time);
733+
if (on_merge != NULL) {
734+
on_merge(output);
735+
}
723736
tile_count++;
724737
}
725738
last_x = false;

stable-diffusion.cpp

Lines changed: 69 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,21 @@ class StableDiffusionGGML {
806806
sd_set_progress_callback(cb, cbd);
807807
}
808808

809+
const float (*get_latent_rgb_proj(enum SDVersion version))[3] {
810+
if (sd_version_is_sd3(version)) {
811+
return sd3_latent_rgb_proj;
812+
} else if (sd_version_is_flux(version)) {
813+
return flux_latent_rgb_proj;
814+
} else if (sd_version_is_sdxl(version)) {
815+
return sdxl_latent_rgb_proj;
816+
} else if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) {
817+
return sd_latent_rgb_proj;
818+
} else {
819+
LOG_WARN("No latent to RGB projection known for this model");
820+
return NULL;
821+
}
822+
}
823+
809824
void preview_image(ggml_context* work_ctx,
810825
int step,
811826
struct ggml_tensor* latents,
@@ -820,33 +835,8 @@ class StableDiffusionGGML {
820835
if (preview_mode == SD_PREVIEW_PROJ) {
821836
const float(*latent_rgb_proj)[channel];
822837

823-
if (dim == 16) {
824-
// 16 channels VAE -> Flux or SD3
825-
826-
if (sd_version_is_sd3(version)) {
827-
latent_rgb_proj = sd3_latent_rgb_proj;
828-
} else if (sd_version_is_flux(version)) {
829-
latent_rgb_proj = flux_latent_rgb_proj;
830-
} else {
831-
LOG_WARN("No latent to RGB projection known for this model");
832-
// unknown model
833-
return;
834-
}
835-
836-
} else if (dim == 4) {
837-
// 4 channels VAE
838-
if (sd_version_is_sdxl(version)) {
839-
latent_rgb_proj = sdxl_latent_rgb_proj;
840-
} else if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) {
841-
latent_rgb_proj = sd_latent_rgb_proj;
842-
} else {
843-
// unknown model
844-
LOG_WARN("No latent to RGB projection known for this model");
845-
return;
846-
}
847-
} else {
848-
LOG_WARN("No latent to RGB projection known for this model");
849-
// unknown latent space
838+
latent_rgb_proj = get_latent_rgb_proj(version);
839+
if (latent_rgb_proj == NULL) {
850840
return;
851841
}
852842
uint8_t* data = (uint8_t*)malloc(width * height * channel * sizeof(uint8_t));
@@ -1237,7 +1227,56 @@ class StableDiffusionGGML {
12371227
decode ? (H * 8) : (H / 8), // height
12381228
decode ? 3 : C,
12391229
x->ne[3]); // channels
1240-
int64_t t0 = ggml_time_ms();
1230+
1231+
if (decode && vae_tiling) {
1232+
const float(*latent_rgb_proj)[3];
1233+
latent_rgb_proj = get_latent_rgb_proj(version);
1234+
if (latent_rgb_proj != NULL) {
1235+
uint8_t* data = (uint8_t*)malloc(W * H * 3 * sizeof(uint8_t));
1236+
1237+
preview_latent_image(data, x, latent_rgb_proj, W, H, C / 2);
1238+
1239+
// fill result with upscaled data
1240+
for (int w = 0; w < W; w++) {
1241+
for (int h = 0; h < H; h++) {
1242+
for (int c = 0; c < 3; c++) {
1243+
// int i = (w * H + h) * 3 + c; //wrong
1244+
int i = (h * W + w) * 3 + c;
1245+
float value = data[i] / 255.0f;
1246+
if (!use_tiny_autoencoder) {
1247+
value = value * 2.0f - 1.0f;
1248+
}
1249+
for (int x = 0; x < 8; x++) {
1250+
for (int y = 0; y < 8; y++) {
1251+
ggml_tensor_set_f32(result, value, w * 8 + x, h * 8 + y, c);
1252+
}
1253+
}
1254+
}
1255+
}
1256+
}
1257+
free(data);
1258+
// upscale
1259+
}
1260+
}
1261+
auto preview_cb = sd_get_preview_callback();
1262+
auto on_tile_merged = [&](ggml_tensor* output) {
1263+
if (preview_cb && output->ne[2] == 3) {
1264+
if (!use_tiny_autoencoder) {
1265+
ggml_tensor_scale_output(output);
1266+
}
1267+
sd_image_t image = {
1268+
output->ne[0],
1269+
output->ne[1],
1270+
3,
1271+
sd_tensor_to_image(output)};
1272+
preview_cb(-1, image);
1273+
free(image.data);
1274+
if (!use_tiny_autoencoder) {
1275+
ggml_tensor_scale_input(output);
1276+
}
1277+
}
1278+
};
1279+
int64_t t0 = ggml_time_ms();
12411280
if (!use_tiny_autoencoder) {
12421281
if (decode) {
12431282
ggml_tensor_scale(x, 1.0f / scale_factor);
@@ -1249,7 +1288,7 @@ class StableDiffusionGGML {
12491288
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
12501289
first_stage_model->compute(n_threads, in, decode, &out);
12511290
};
1252-
sd_tiling(x, result, 8, 32, 0.5f, on_tiling, decode);
1291+
sd_tiling(x, result, 8, 32, 0.5f, on_tiling, decode, on_tile_merged);
12531292
} else {
12541293
first_stage_model->compute(n_threads, x, decode, &result);
12551294
}
@@ -1263,7 +1302,7 @@ class StableDiffusionGGML {
12631302
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
12641303
tae_first_stage->compute(n_threads, in, decode, &out);
12651304
};
1266-
sd_tiling(x, result, 8, 64, 0.5f, on_tiling, decode);
1305+
sd_tiling(x, result, 8, 64, 0.5f, on_tiling, decode, on_tile_merged);
12671306
} else {
12681307
tae_first_stage->compute(n_threads, x, decode, &result);
12691308
}

0 commit comments

Comments
 (0)