Skip to content

Commit 3e74c35

Browse files
committed
apg: first implementation
refactor guidance params in lib; main: add apg support; add apg settings to image params; fix cfg 1 crash
1 parent 8867e3e commit 3e74c35

File tree

3 files changed

+152
-70
lines changed

3 files changed

+152
-70
lines changed

examples/cli/main.cpp

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,13 @@ struct SDParams {
133133
int upscale_repeats = 1;
134134

135135
std::vector<int> skip_layers = {7, 8, 9};
136-
float slg_scale = 0.;
137-
float skip_layer_start = 0.01;
138-
float skip_layer_end = 0.2;
136+
float slg_scale = 0.0f;
137+
float skip_layer_start = 0.01f;
138+
float skip_layer_end = 0.2f;
139+
140+
float apg_eta = 1.0f;
141+
float apg_momentum = 0.0f;
142+
float apg_norm_threshold = 0.0f;
139143

140144
sd_preview_t preview_method = SD_PREVIEW_NONE;
141145
int preview_interval = 1;
@@ -226,6 +230,9 @@ void print_usage(int argc, const char* argv[]) {
226230
printf(" -p, --prompt [PROMPT] the prompt to render\n");
227231
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
228232
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
233+
printf(" --apg-eta VALUE parallel projected guidance scale for APG (default: 1.0, recommended: between 0 and 1)\n");
234+
printf(" --apg-momentum VALUE CFG update direction momentum for APG (default: 0, recommended: around -0.5)\n");
235+
printf(" --apg-nt, --apg-rescale VALUE CFG update direction norm threshold for APG (default: 0 = disabled, recommended: 4-15)\n");
229236
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
230237
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
231238
printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n");
@@ -648,6 +655,24 @@ void parse_args(int argc, const char** argv, SDParams& params) {
648655
break;
649656
}
650657
params.skip_layer_end = std::stof(argv[i]);
658+
} else if (arg == "--apg-eta") {
659+
if (++i >= argc) {
660+
invalid_arg = true;
661+
break;
662+
}
663+
params.apg_eta = std::stof(argv[i]);
664+
} else if (arg == "--apg-momentum") {
665+
if (++i >= argc) {
666+
invalid_arg = true;
667+
break;
668+
}
669+
params.apg_momentum = std::stof(argv[i]);
670+
} else if (arg == "--apg-nt" || arg == "--apg-rescale") {
671+
if (++i >= argc) {
672+
invalid_arg = true;
673+
break;
674+
}
675+
params.apg_norm_threshold = std::stof(argv[i]);
651676
} else if (arg == "--preview") {
652677
if (++i >= argc) {
653678
invalid_arg = true;
@@ -767,6 +792,15 @@ std::string get_image_params(SDParams params, int64_t seed) {
767792
}
768793
parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", ";
769794
parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", ";
795+
if (params.apg_eta != 1) {
796+
parameter_string += "APG eta: " + std::to_string(params.apg_eta) + ", ";
797+
}
798+
if (params.apg_momentum != 0) {
799+
parameter_string += "CFG momentum: " + std::to_string(params.apg_momentum) + ", ";
800+
}
801+
if (params.apg_norm_threshold != 0) {
802+
parameter_string += "CFG normalization threshold: " + std::to_string(params.apg_norm_threshold) + ", ";
803+
}
770804
if (params.slg_scale != 0 && params.skip_layers.size() != 0) {
771805
parameter_string += "SLG scale: " + std::to_string(params.cfg_scale) + ", ";
772806
parameter_string += "Skip layers: [";
@@ -1020,11 +1054,14 @@ int main(int argc, const char* argv[]) {
10201054
params.style_ratio,
10211055
params.normalize_input,
10221056
params.input_id_images_path.c_str(),
1023-
params.skip_layers.data(),
1024-
params.skip_layers.size(),
1025-
params.slg_scale,
1026-
params.skip_layer_start,
1027-
params.skip_layer_end);
1057+
sd_slg_params_t{params.skip_layers.data(),
1058+
params.skip_layers.size(),
1059+
params.slg_scale,
1060+
params.skip_layer_start,
1061+
params.skip_layer_end},
1062+
sd_apg_params_t{params.apg_eta,
1063+
params.apg_momentum,
1064+
params.apg_norm_threshold});
10281065
} else {
10291066
sd_image_t input_image = {(uint32_t)params.width,
10301067
(uint32_t)params.height,
@@ -1089,11 +1126,14 @@ int main(int argc, const char* argv[]) {
10891126
params.style_ratio,
10901127
params.normalize_input,
10911128
params.input_id_images_path.c_str(),
1092-
params.skip_layers.data(),
1093-
params.skip_layers.size(),
1094-
params.slg_scale,
1095-
params.skip_layer_start,
1096-
params.skip_layer_end);
1129+
sd_slg_params_t{params.skip_layers.data(),
1130+
params.skip_layers.size(),
1131+
params.slg_scale,
1132+
params.skip_layer_start,
1133+
params.skip_layer_end},
1134+
sd_apg_params_t{params.apg_eta,
1135+
params.apg_momentum,
1136+
params.apg_norm_threshold});
10971137
}
10981138
}
10991139

stable-diffusion.cpp

Lines changed: 81 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ const char* sampling_methods_str[] = {
5050
"iPNDM_v",
5151
"LCM",
5252
"DDIM \"trailing\"",
53-
"TCD"
54-
};
53+
"TCD"};
5554

5655
/*================================================== Helper Functions ================================================*/
5756

@@ -696,7 +695,7 @@ class StableDiffusionGGML {
696695
float curr_multiplier = kv.second;
697696
lora_state_diff[lora_name] -= curr_multiplier;
698697
}
699-
698+
700699
size_t rm = lora_state_diff.size() - lora_state.size();
701700
if (rm != 0) {
702701
LOG_INFO("Attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm);
@@ -918,16 +917,15 @@ class StableDiffusionGGML {
918917
float min_cfg,
919918
float cfg_scale,
920919
float guidance,
921-
float eta,
920+
float eta,
922921
sample_method_t method,
923922
const std::vector<float>& sigmas,
924923
int start_merge_step,
925924
SDCondition id_cond,
926-
std::vector<int> skip_layers = {},
927-
float slg_scale = 0,
928-
float skip_layer_start = 0.01,
929-
float skip_layer_end = 0.2,
930-
ggml_tensor* noise_mask = nullptr) {
925+
sd_slg_params_t slg_params = {NULL, 0, 0, 0, 0},
926+
sd_apg_params_t apg_params = {1, 0, 0},
927+
ggml_tensor* noise_mask = nullptr) {
928+
std::vector<int> skip_layers(slg_params.skip_layers, slg_params.skip_layers + slg_params.skip_layers_count);
931929
size_t steps = sigmas.size() - 1;
932930
// noise = load_tensor_from_file(work_ctx, "./rand0.bin");
933931
// print_ggml_tensor(noise);
@@ -938,7 +936,7 @@ class StableDiffusionGGML {
938936
struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, noise);
939937

940938
bool has_unconditioned = cfg_scale != 1.0 && uncond.c_crossattn != NULL;
941-
bool has_skiplayer = slg_scale != 0.0 && skip_layers.size() > 0;
939+
bool has_skiplayer = slg_params.scale != 0.0 && skip_layers.size() > 0;
942940

943941
// denoise wrapper
944942
struct ggml_tensor* out_cond = ggml_dup_tensor(work_ctx, x);
@@ -959,7 +957,7 @@ class StableDiffusionGGML {
959957
struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x);
960958

961959
struct ggml_tensor* preview_tensor = NULL;
962-
auto sd_preview_mode = sd_get_preview_mode();
960+
auto sd_preview_mode = sd_get_preview_mode();
963961
if (sd_preview_mode != SD_PREVIEW_NONE && sd_preview_mode != SD_PREVIEW_PROJ) {
964962
preview_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32,
965963
(denoised->ne[0] * 8),
@@ -968,6 +966,10 @@ class StableDiffusionGGML {
968966
denoised->ne[3]);
969967
}
970968

969+
std::vector<float> apg_momentum_buffer;
970+
if (apg_params.momentum != 0)
971+
apg_momentum_buffer.resize((size_t)ggml_nelements(denoised));
972+
971973
auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* {
972974
if (step == 1) {
973975
pretty_progress(0, (int)steps, 0);
@@ -1048,7 +1050,7 @@ class StableDiffusionGGML {
10481050
}
10491051

10501052
int step_count = sigmas.size();
1051-
bool is_skiplayer_step = has_skiplayer && step > (int)(skip_layer_start * step_count) && step < (int)(skip_layer_end * step_count);
1053+
bool is_skiplayer_step = has_skiplayer && step > (int)(slg_params.skip_layer_start * step_count) && step < (int)(slg_params.skip_layer_end * step_count);
10521054
float* skip_layer_data = NULL;
10531055
if (is_skiplayer_step) {
10541056
LOG_DEBUG("Skipping layers at step %d\n", step);
@@ -1072,6 +1074,52 @@ class StableDiffusionGGML {
10721074
float* vec_input = (float*)input->data;
10731075
float* positive_data = (float*)out_cond->data;
10741076
int ne_elements = (int)ggml_nelements(denoised);
1077+
1078+
float* deltas = vec_denoised;
1079+
1080+
// https://arxiv.org/pdf/2410.02416
1081+
float apg_scale_factor = 1.;
1082+
float diff_norm = 0;
1083+
float cond_norm_sq = 0;
1084+
float dot = 0;
1085+
if (has_unconditioned) {
1086+
for (int i = 0; i < ne_elements; i++) {
1087+
float delta = positive_data[i] - negative_data[i];
1088+
if (apg_params.momentum != 0) {
1089+
delta += apg_params.momentum * apg_momentum_buffer[i];
1090+
apg_momentum_buffer[i] = delta;
1091+
}
1092+
if (apg_params.norm_treshold > 0) {
1093+
diff_norm += delta * delta;
1094+
}
1095+
if (apg_params.eta != 1.0f) {
1096+
cond_norm_sq += positive_data[i] * positive_data[i];
1097+
dot += positive_data[i] * delta;
1098+
}
1099+
deltas[i] = delta;
1100+
}
1101+
if (apg_params.norm_treshold > 0) {
1102+
diff_norm = std::sqrtf(diff_norm);
1103+
apg_scale_factor = std::min(1.0f, apg_params.norm_treshold / diff_norm);
1104+
}
1105+
if (apg_params.eta != 1.0f) {
1106+
dot *= apg_scale_factor;
1107+
// pre-normalize (avoids one square root and ne_elements extra divs)
1108+
dot /= cond_norm_sq;
1109+
}
1110+
1111+
for (int i = 0; i < ne_elements; i++) {
1112+
deltas[i] *= apg_scale_factor;
1113+
if (apg_params.eta != 1.0f) {
1114+
float apg_parallel = dot * positive_data[i];
1115+
float apg_orthogonal = deltas[i] - apg_parallel;
1116+
1117+
// tweak deltas
1118+
deltas[i] = apg_orthogonal + apg_params.eta * apg_parallel;
1119+
}
1120+
}
1121+
}
1122+
10751123
for (int i = 0; i < ne_elements; i++) {
10761124
float latent_result = positive_data[i];
10771125
if (has_unconditioned) {
@@ -1081,11 +1129,13 @@ class StableDiffusionGGML {
10811129
int64_t i3 = i / out_cond->ne[0] * out_cond->ne[1] * out_cond->ne[2];
10821130
float scale = min_cfg + (cfg_scale - min_cfg) * (i3 * 1.0f / ne3);
10831131
} else {
1084-
latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]);
1132+
float delta = deltas[i];
1133+
1134+
latent_result = positive_data[i] + (cfg_scale - 1) * delta;
10851135
}
10861136
}
10871137
if (is_skiplayer_step) {
1088-
latent_result = latent_result + (positive_data[i] - skip_layer_data[i]) * slg_scale;
1138+
latent_result = latent_result + (positive_data[i] - skip_layer_data[i]) * slg_params.scale;
10891139
}
10901140
// v = latent_result, eps = latent_result
10911141
// denoised = (v * c_out + input * c_skip) or (input + eps * c_out)
@@ -1108,7 +1158,7 @@ class StableDiffusionGGML {
11081158
pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
11091159
// LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
11101160
}
1111-
auto sd_preview_cb = sd_get_preview_callback();
1161+
auto sd_preview_cb = sd_get_preview_callback();
11121162
auto sd_preview_mode = sd_get_preview_mode();
11131163
if (sd_preview_cb != NULL) {
11141164
if (step % sd_get_preview_interval() == 0) {
@@ -1131,7 +1181,8 @@ class StableDiffusionGGML {
11311181
}
11321182

11331183
// ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding
1134-
ggml_tensor* get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* moments) {
1184+
ggml_tensor*
1185+
get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* moments) {
11351186
// ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample
11361187
ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
11371188
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent);
@@ -1338,11 +1389,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13381389
float style_ratio,
13391390
bool normalize_input,
13401391
std::string input_id_images_path,
1341-
std::vector<int> skip_layers = {},
1342-
float slg_scale = 0,
1343-
float skip_layer_start = 0.01,
1344-
float skip_layer_end = 0.2,
1345-
ggml_tensor* masked_image = NULL) {
1392+
sd_slg_params_t slg_params,
1393+
sd_apg_params_t apg_params,
1394+
ggml_tensor* masked_image = NULL) {
13461395
if (seed < 0) {
13471396
// Generally, when using the provided command line, the seed is always >0.
13481397
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -1595,10 +1644,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
15951644
sigmas,
15961645
start_merge_step,
15971646
id_cond,
1598-
skip_layers,
1599-
slg_scale,
1600-
skip_layer_start,
1601-
skip_layer_end,
1647+
slg_params,
1648+
apg_params,
16021649
noise_mask);
16031650

16041651
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
@@ -1668,12 +1715,8 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
16681715
float style_ratio,
16691716
bool normalize_input,
16701717
const char* input_id_images_path_c_str,
1671-
int* skip_layers = NULL,
1672-
size_t skip_layers_count = 0,
1673-
float slg_scale = 0,
1674-
float skip_layer_start = 0.01,
1675-
float skip_layer_end = 0.2) {
1676-
std::vector<int> skip_layers_vec(skip_layers, skip_layers + skip_layers_count);
1718+
sd_slg_params_t slg_params,
1719+
sd_apg_params_t apg_params) {
16771720
LOG_DEBUG("txt2img %dx%d", width, height);
16781721
if (sd_ctx == NULL) {
16791722
return NULL;
@@ -1751,10 +1794,8 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
17511794
style_ratio,
17521795
normalize_input,
17531796
input_id_images_path_c_str,
1754-
skip_layers_vec,
1755-
slg_scale,
1756-
skip_layer_start,
1757-
skip_layer_end,
1797+
slg_params,
1798+
apg_params,
17581799
NULL);
17591800

17601801
size_t t1 = ggml_time_ms();
@@ -1785,12 +1826,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
17851826
float style_ratio,
17861827
bool normalize_input,
17871828
const char* input_id_images_path_c_str,
1788-
int* skip_layers = NULL,
1789-
size_t skip_layers_count = 0,
1790-
float slg_scale = 0,
1791-
float skip_layer_start = 0.01,
1792-
float skip_layer_end = 0.2) {
1793-
std::vector<int> skip_layers_vec(skip_layers, skip_layers + skip_layers_count);
1829+
sd_slg_params_t slg_params,
1830+
sd_apg_params_t apg_params) {
17941831
LOG_DEBUG("img2img %dx%d", width, height);
17951832
if (sd_ctx == NULL) {
17961833
return NULL;
@@ -1932,10 +1969,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
19321969
style_ratio,
19331970
normalize_input,
19341971
input_id_images_path_c_str,
1935-
skip_layers_vec,
1936-
slg_scale,
1937-
skip_layer_start,
1938-
skip_layer_end,
1972+
slg_params,
1973+
apg_params,
19391974
masked_image);
19401975

19411976
size_t t2 = ggml_time_ms();
@@ -2039,8 +2074,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
20392074
sigmas,
20402075
-1,
20412076
SDCondition(NULL, NULL, NULL),
2042-
{},
2043-
0, 0, 0, NULL);
2077+
{}, {}, NULL);
20442078

20452079
int64_t t2 = ggml_time_ms();
20462080
LOG_INFO("sampling completed, taking %.2fs", (t2 - t1) * 1.0f / 1000);

0 commit comments

Comments
 (0)