Skip to content

Commit 570e26a

Browse files
committed
refactor default tile size, limit overlap factor
1 parent 99616b9 commit 570e26a

File tree

1 file changed

+27
-39
lines changed

1 file changed

+27
-39
lines changed

stable-diffusion.cpp

Lines changed: 27 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,25 +1300,28 @@ class StableDiffusionGGML {
13001300
return latent;
13011301
}
13021302

1303-
void get_relative_tile_sizes(int& tile_size_x, int& tile_size_y, float tile_overlap, float rel_size_x, float rel_size_y, int latent_x, int latent_y) {
1304-
// format is AxB, or just A (equivalent to AxA)
1305-
// A and B can be integers (tile size) or floating point
1306-
// floating point <= 1 means simple fraction of the latent dimension
1307-
// floating point > 1 means number of tiles across that dimension
1308-
// a single number gets applied to both
1309-
auto get_tile_factor = [tile_overlap](float factor) {
1310-
if (factor > 1.0)
1311-
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1312-
return factor;
1313-
};
1314-
const int min_tile_dimension = 4;
13151303

1316-
int tmp_x = tile_size_x, tmp_y = tile_size_y;
1317-
tmp_x = std::round(latent_x * get_tile_factor(rel_size_x));
1318-
tmp_y = std::round(latent_y * get_tile_factor(rel_size_y));
1304+
// Computes the VAE tile size for both axes plus the tile overlap to use, and
// writes all three through the reference parameters.
//
// tile_size_x / tile_size_y : outputs, resulting tile size per axis
// tile_overlap              : output, params.target_overlap clamped to [0, 0.5]
// params                    : tiling settings; this function reads target_overlap,
//                             relative, tile_size_x/y and rel_size_x/y
// latent_x / latent_y       : size of the dimension being tiled, used as the
//                             upper bound for the tile size on each axis
//
// The chosen sizes are logged at INFO level before returning.
void get_tile_sizes(int& tile_size_x, int& tile_size_y, float& tile_overlap, const sd_tiling_params_t & params, int latent_x, int latent_y) {
1305+
tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f); // limit the overlap factor to [0, 0.5]
1306+
// Per-axis helper. Captures params and the already clamped tile_overlap by reference.
auto get_tile_size = [&](int requested_size, float factor, int latent_size) {
1307+
const int default_tile_size = 32; // used when no valid absolute size is requested
1308+
const int min_tile_dimension = 4; // lower bound for any returned tile size
1309+
int tile_size = default_tile_size;
1310+
// rel_size <= 1 means simple fraction of the latent dimension
1311+
// rel_size > 1 means number of tiles across that dimension
1312+
if (params.relative) {
1313+
if (factor > 1.0)
1314+
// Convert a tile count n into a fraction of the dimension: n tiles that
// overlap by a fraction o of their size cover (n - n*o + o) tile lengths.
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1315+
tile_size = std::round(latent_size * factor); // rounded to the nearest integer size
1316+
}
1317+
else if (requested_size >= min_tile_dimension) {
1318+
tile_size = requested_size; // absolute size; values below the minimum keep the default of 32
1319+
}
1320+
// Cap at the dimension size first, then enforce the minimum. The minimum is
// applied last, so the result is 4 even when latent_size is smaller than 4.
return std::max(std::min(tile_size, latent_size), min_tile_dimension);
1321+
};
13191322

1320-
tile_size_x = std::max(std::min(tmp_x, latent_x), min_tile_dimension);
1321-
tile_size_y = std::max(std::min(tmp_y, latent_y), min_tile_dimension);
1323+
tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x);
1324+
tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y);
13221325

13231326
LOG_INFO("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
13241327
}
@@ -1336,19 +1339,11 @@ class StableDiffusionGGML {
13361339
}
13371340
result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, x->ne[3]);
13381341
}
1339-
// TODO: args instead of env for tile size / overlap?
1340-
if (!use_tiny_autoencoder) {
1341-
float tile_overlap = vae_tiling_params.target_overlap;
1342-
int tile_size_x = (vae_tiling_params.tile_size_x >= 4)
1343-
? vae_tiling_params.tile_size_x
1344-
: 32;
1345-
int tile_size_y = (vae_tiling_params.tile_size_y >= 4)
1346-
? vae_tiling_params.tile_size_y
1347-
: 32;
13481342

1349-
if (vae_tiling_params.relative) {
1350-
get_relative_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params.rel_size_x, vae_tiling_params.rel_size_y, W, H);
1351-
}
1343+
if (!use_tiny_autoencoder) {
1344+
float tile_overlap;
1345+
int tile_size_x, tile_size_y;
1346+
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, W, H);
13521347

13531348
// TODO: also use an arg for this one?
13541349
// multiply tile size for encode to keep the compute buffer size consistent
@@ -1493,17 +1488,10 @@ class StableDiffusionGGML {
14931488
}
14941489
int64_t t0 = ggml_time_ms();
14951490
if (!use_tiny_autoencoder) {
1496-
float tile_overlap = vae_tiling_params.target_overlap;
1497-
int tile_size_x = (vae_tiling_params.tile_size_x >= 4)
1498-
? vae_tiling_params.tile_size_x
1499-
: 32;
1500-
int tile_size_y = (vae_tiling_params.tile_size_y >= 4)
1501-
? vae_tiling_params.tile_size_y
1502-
: 32;
1491+
float tile_overlap;
1492+
int tile_size_x, tile_size_y;
1493+
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, W, H);
15031494

1504-
if (vae_tiling_params.relative) {
1505-
get_relative_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params.rel_size_x, vae_tiling_params.rel_size_y, x->ne[0], x->ne[1]);
1506-
}
15071495
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
15081496

15091497
process_latent_out(x);

0 commit comments

Comments
 (0)