Skip to content

Commit 4798be9

Browse files
committed
add safeguards and include overlapping factor for SD_TILE_SIZE
1 parent fe84190 commit 4798be9

File tree

1 file changed

+19
-24
lines changed

1 file changed

+19
-24
lines changed

stable-diffusion.cpp

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,9 +1456,18 @@ class StableDiffusionGGML {
14561456
if (SD_TILE_SIZE != nullptr) {
14571457
// format is AxB, or just A (equivalent to AxA)
14581458
// A and B can be integers (tile size) or floating point
1459-
// floating point <= 1 means fraction of the latent dimension
1460-
// floating point > 1 means number of tiles in that dimension
1461-
// a single number gets applied to both dimensions
1459+
// floating point <= 1 means simple fraction of the latent dimension
1460+
// floating point > 1 means number of tiles across that dimension
1461+
// a single number gets applied to both
1462+
auto get_tile_factor = [tile_overlap](const std::string& factor_str) {
1463+
float factor = std::stof(factor_str);
1464+
if (factor > 1.0)
1465+
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1466+
return factor;
1467+
};
1468+
const int latent_x = W / (decode ? 1 : 8);
1469+
const int latent_y = H / (decode ? 1 : 8);
1470+
const int min_tile_dimension = 4;
14621471
std::string sd_tile_size_str = SD_TILE_SIZE;
14631472
size_t x_pos = sd_tile_size_str.find('x');
14641473
try {
@@ -1467,44 +1476,30 @@ class StableDiffusionGGML {
14671476
std::string tile_x_str = sd_tile_size_str.substr(0, x_pos);
14681477
std::string tile_y_str = sd_tile_size_str.substr(x_pos + 1);
14691478
if (tile_x_str.find('.') != std::string::npos) {
1470-
float tile_factor = std::stof(tile_x_str);
1471-
if (tile_factor > 0.0) {
1472-
if (tile_factor > 1.0)
1473-
tile_factor = 1.0 / tile_factor;
1474-
tmp_x = (W / (decode ? 1 : 8)) * tile_factor;
1475-
}
1479+
tmp_x = latent_x * get_tile_factor(tile_x_str);
14761480
}
14771481
else {
14781482
tmp_x = std::stoi(tile_x_str);
14791483
}
14801484
if (tile_y_str.find('.') != std::string::npos) {
1481-
float tile_factor = std::stof(tile_y_str);
1482-
if (tile_factor > 0.0) {
1483-
if (tile_factor > 1.0)
1484-
tile_factor = 1.0 / tile_factor;
1485-
tmp_y = (H / (decode ? 1 : 8)) * tile_factor;
1486-
}
1485+
tmp_y = latent_y * get_tile_factor(tile_y_str);
14871486
}
14881487
else {
14891488
tmp_y = std::stoi(tile_y_str);
14901489
}
14911490
}
14921491
else {
14931492
if (sd_tile_size_str.find('.') != std::string::npos) {
1494-
float tile_factor = std::stof(sd_tile_size_str);
1495-
if (tile_factor > 0) {
1496-
if (tile_factor > 1.0)
1497-
tile_factor = 1.0 / tile_factor;
1498-
tmp_x = (W / (decode ? 1 : 8)) * tile_factor;
1499-
tmp_y = (H / (decode ? 1 : 8)) * tile_factor;
1500-
}
1493+
float tile_factor = get_tile_factor(sd_tile_size_str);
1494+
tmp_x = latent_x * tile_factor;
1495+
tmp_y = latent_y * tile_factor;
15011496
}
15021497
else {
15031498
tmp_x = tmp_y = std::stoi(sd_tile_size_str);
15041499
}
15051500
}
1506-
tile_size_x = tmp_x;
1507-
tile_size_y = tmp_y;
1501+
tile_size_x = std::max(std::min(tmp_x, latent_x), min_tile_dimension);
1502+
tile_size_y = std::max(std::min(tmp_y, latent_y), min_tile_dimension);
15081503
} catch (const std::invalid_argument&) {
15091504
LOG_WARN("SD_TILE_SIZE is invalid, keeping the default");
15101505
} catch (const std::out_of_range&) {

0 commit comments

Comments
 (0)