Skip to content

Commit c15ac2e

Browse files
committed
Implement tiled VAE encode support
1 parent 6d84a30 commit c15ac2e

File tree

2 files changed

+30
-16
lines changed

2 files changed

+30
-16
lines changed

ggml_extend.hpp

Lines changed: 26 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -601,21 +601,31 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
601601
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
602602

603603
// Tiling
604-
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
604+
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true) {
605605
output = ggml_set_f32(output, 0);
606606

607607
int input_width = (int)input->ne[0];
608608
int input_height = (int)input->ne[1];
609609
int output_width = (int)output->ne[0];
610610
int output_height = (int)output->ne[1];
611+
612+
int input_tile_size, output_tile_size;
613+
if (scaled_out) {
614+
input_tile_size = tile_size;
615+
output_tile_size = tile_size * scale;
616+
} else {
617+
input_tile_size = tile_size * scale;
618+
output_tile_size = tile_size;
619+
}
620+
611621
GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2
612622

613-
int tile_overlap = (int32_t)(tile_size * tile_overlap_factor);
614-
int non_tile_overlap = tile_size - tile_overlap;
623+
int tile_overlap = (int32_t)(input_tile_size * tile_overlap_factor);
624+
int non_tile_overlap = input_tile_size - tile_overlap;
615625

616626
struct ggml_init_params params = {};
617-
params.mem_size += tile_size * tile_size * input->ne[2] * sizeof(float); // input chunk
618-
params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne[2] * sizeof(float); // output chunk
627+
params.mem_size += input_tile_size * input_tile_size * input->ne[2] * sizeof(float); // input chunk
628+
params.mem_size += output_tile_size * output_tile_size * output->ne[2] * sizeof(float); // output chunk
619629
params.mem_size += 3 * ggml_tensor_overhead();
620630
params.mem_buffer = NULL;
621631
params.no_alloc = false;
@@ -630,8 +640,8 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
630640
}
631641

632642
// tiling
633-
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size, tile_size, input->ne[2], 1);
634-
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale, output->ne[2], 1);
643+
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size, input->ne[2], 1);
644+
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size, output->ne[2], 1);
635645
on_processing(input_tile, NULL, true);
636646
int num_tiles = ceil((float)input_width / non_tile_overlap) * ceil((float)input_height / non_tile_overlap);
637647
LOG_INFO("processing %i tiles", num_tiles);
@@ -640,19 +650,23 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
640650
bool last_y = false, last_x = false;
641651
float last_time = 0.0f;
642652
for (int y = 0; y < input_height && !last_y; y += non_tile_overlap) {
643-
if (y + tile_size >= input_height) {
644-
y = input_height - tile_size;
653+
if (y + input_tile_size >= input_height) {
654+
y = input_height - input_tile_size;
645655
last_y = true;
646656
}
647657
for (int x = 0; x < input_width && !last_x; x += non_tile_overlap) {
648-
if (x + tile_size >= input_width) {
649-
x = input_width - tile_size;
658+
if (x + input_tile_size >= input_width) {
659+
x = input_width - input_tile_size;
650660
last_x = true;
651661
}
652662
int64_t t1 = ggml_time_ms();
653663
ggml_split_tensor_2d(input, input_tile, x, y);
654664
on_processing(input_tile, output_tile, false);
655-
ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap * scale);
665+
if (scaled_out) {
666+
ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap * scale);
667+
} else {
668+
ggml_merge_tensor_2d(output_tile, output, x / scale, y / scale, tile_overlap / scale);
669+
}
656670
int64_t t2 = ggml_time_ms();
657671
last_time = (t2 - t1) / 1000.0f;
658672
pretty_progress(tile_count, num_tiles, last_time);

stable-diffusion.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1099,12 +1099,12 @@ class StableDiffusionGGML {
10991099
} else {
11001100
ggml_tensor_scale_input(x);
11011101
}
1102-
if (vae_tiling && decode) { // TODO: support tiling vae encode
1102+
if (vae_tiling) {
11031103
// split the input into tiles (32x32 in latent space) and compute in several steps
11041104
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
11051105
first_stage_model->compute(n_threads, in, decode, &out);
11061106
};
1107-
sd_tiling(x, result, 8, 32, 0.5f, on_tiling);
1107+
sd_tiling(x, result, 8, 32, 0.5f, on_tiling, decode);
11081108
} else {
11091109
first_stage_model->compute(n_threads, x, decode, &result);
11101110
}
@@ -1113,12 +1113,12 @@ class StableDiffusionGGML {
11131113
ggml_tensor_scale_output(result);
11141114
}
11151115
} else {
1116-
if (vae_tiling && decode) { // TODO: support tiling vae encode
1116+
if (vae_tiling) {
11171117
// split the input into tiles (64x64 in latent space) and compute in several steps
11181118
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
11191119
tae_first_stage->compute(n_threads, in, decode, &out);
11201120
};
1121-
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
1121+
sd_tiling(x, result, 8, 64, 0.5f, on_tiling, decode);
11221122
} else {
11231123
tae_first_stage->compute(n_threads, x, decode, &result);
11241124
}

0 commit comments

Comments
 (0)