Skip to content

Commit 316937b

Browse files
committed
apg: first implementation
refactor guidance params in lib; main: add apg support; add apg settings to image params; Fix cfg 1 crash; Fix CI build
1 parent 1896b28 commit 316937b

File tree

3 files changed

+92
-11
lines changed

3 files changed

+92
-11
lines changed

examples/cli/main.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,14 @@ struct SDParams {
102102
int upscale_repeats = 1;
103103

104104
std::vector<int> skip_layers = {7, 8, 9};
105-
float slg_scale = 0.f;
105+
float slg_scale = 0.0f;
106106
float skip_layer_start = 0.01f;
107107
float skip_layer_end = 0.2f;
108108

109+
float apg_eta = 1.0f;
110+
float apg_momentum = 0.0f;
111+
float apg_norm_threshold = 0.0f;
112+
109113
bool chroma_use_dit_mask = true;
110114
bool chroma_use_t5_mask = false;
111115
int chroma_t5_mask_pad = 1;
@@ -204,6 +208,9 @@ void print_usage(int argc, const char* argv[]) {
204208
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
205209
printf(" --img-cfg-scale SCALE image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)\n");
206210
printf(" --guidance SCALE distilled guidance scale for models with guidance input (default: 3.5)\n");
211+
printf(" --apg-eta VALUE parallel projected guidance scale for APG (default: 1.0, recommended: between 0 and 1)\n");
212+
printf(" --apg-momentum VALUE CFG update direction momentum for APG (default: 0, recommended: around -0.5)\n");
213+
printf(" --apg-nt, --apg-rescale VALUE CFG update direction norm threshold for APG (default: 0 = disabled, recommended: 4-15)\n");
207214
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
208215
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
209216
printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n");
@@ -660,6 +667,15 @@ std::string get_image_params(SDParams params, int64_t seed) {
660667
}
661668
parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", ";
662669
parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", ";
670+
if (params.apg_eta != 1) {
671+
parameter_string += "APG eta: " + std::to_string(params.apg_eta) + ", ";
672+
}
673+
if (params.apg_momentum != 0) {
674+
parameter_string += "CFG momentum: " + std::to_string(params.apg_momentum) + ", ";
675+
}
676+
if (params.apg_norm_threshold != 0) {
677+
parameter_string += "CFG normalization threshold: " + std::to_string(params.apg_norm_threshold) + ", ";
678+
}
663679
if (params.slg_scale != 0 && params.skip_layers.size() != 0) {
664680
parameter_string += "SLG scale: " + std::to_string(params.cfg_scale) + ", ";
665681
parameter_string += "Skip layers: [";
@@ -967,7 +983,7 @@ int main(int argc, const char* argv[]) {
967983
params.input_id_images_path.c_str(),
968984
};
969985

970-
results = generate_image(sd_ctx, &img_gen_params);
986+
results = generate_image(sd_ctx, &img_gen_params, {params.apg_eta, params.apg_momentum, params.apg_norm_threshold});
971987
expected_num_results = params.batch_count;
972988
} else if (params.mode == VID_GEN) {
973989
sd_vid_gen_params_t vid_gen_params = {

stable-diffusion.cpp

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,7 @@ class StableDiffusionGGML {
848848
int start_merge_step,
849849
SDCondition id_cond,
850850
std::vector<ggml_tensor*> ref_latents = {},
851+
sd_apg_params_t apg_params = {1, 0, 0},
851852
ggml_tensor* denoise_mask = nullptr) {
852853
std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
853854

@@ -909,6 +910,10 @@ class StableDiffusionGGML {
909910
}
910911
struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x);
911912

913+
std::vector<float> apg_momentum_buffer;
914+
if (apg_params.momentum != 0)
915+
apg_momentum_buffer.resize((size_t)ggml_nelements(denoised));
916+
912917
auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* {
913918
if (step == 1) {
914919
pretty_progress(0, (int)steps, 0);
@@ -1034,6 +1039,56 @@ class StableDiffusionGGML {
10341039
float* vec_input = (float*)input->data;
10351040
float* positive_data = (float*)out_cond->data;
10361041
int ne_elements = (int)ggml_nelements(denoised);
1042+
1043+
float* deltas = vec_denoised;
1044+
1045+
// https://arxiv.org/pdf/2410.02416
1046+
float apg_scale_factor = 1.;
1047+
float diff_norm = 0;
1048+
float cond_norm_sq = 0;
1049+
float dot = 0;
1050+
if (has_unconditioned || has_img_cond) {
1051+
for (int i = 0; i < ne_elements; i++) {
1052+
float delta;
1053+
if (has_img_cond) {
1054+
if (cfg_scale == 1) {
1055+
// Weird guidance (important: use img_cfg_scale instead of cfg_scale in the final formula)
1056+
delta = img_cond_data[i] - negative_data[i];
1057+
} else if (has_unconditioned) {
1058+
// 2-conditioning CFG (img_cfg_scale != cfg_scale != 1)
1059+
delta = positive_data[i] + (negative_data[i] * (1 - img_cfg_scale) + img_cond_data[i] * (img_cfg_scale - cfg_scale)) / (cfg_scale - 1);
1060+
} else {
1061+
// pure img CFG (img_cfg_scale == 1, cfg_scale !=1)
1062+
delta = positive_data[i] - img_cond_data[i];
1063+
}
1064+
} else {
1065+
// classic CFG (img_cfg_scale == cfg_scale != 1)
1066+
delta = positive_data[i] - negative_data[i];
1067+
}
1068+
deltas[i] = delta;
1069+
}
1070+
if (apg_params.norm_treshold > 0) {
1071+
diff_norm = sqrtf(diff_norm);
1072+
apg_scale_factor = std::min(1.0f, apg_params.norm_treshold / diff_norm);
1073+
}
1074+
if (apg_params.eta != 1.0f) {
1075+
dot *= apg_scale_factor;
1076+
// pre-normalize (avoids one square root and ne_elements extra divs)
1077+
dot /= cond_norm_sq;
1078+
}
1079+
1080+
for (int i = 0; i < ne_elements; i++) {
1081+
deltas[i] *= apg_scale_factor;
1082+
if (apg_params.eta != 1.0f) {
1083+
float apg_parallel = dot * positive_data[i];
1084+
float apg_orthogonal = deltas[i] - apg_parallel;
1085+
1086+
// tweak deltas
1087+
deltas[i] = apg_orthogonal + apg_params.eta * apg_parallel;
1088+
}
1089+
}
1090+
}
1091+
10371092
for (int i = 0; i < ne_elements; i++) {
10381093
float latent_result = positive_data[i];
10391094
if (has_unconditioned) {
@@ -1043,12 +1098,12 @@ class StableDiffusionGGML {
10431098
int64_t i3 = i / out_cond->ne[0] * out_cond->ne[1] * out_cond->ne[2];
10441099
float scale = min_cfg + (cfg_scale - min_cfg) * (i3 * 1.0f / ne3);
10451100
} else {
1046-
if (has_img_cond) {
1047-
// out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
1048-
latent_result = negative_data[i] + img_cfg_scale * (img_cond_data[i] - negative_data[i]) + cfg_scale * (positive_data[i] - img_cond_data[i]);
1049-
} else {
1050-
// img_cfg_scale == cfg_scale
1051-
latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]);
1101+
float delta = deltas[i];
1102+
1103+
if(cfg_scale != 1) {
1104+
latent_result = positive_data[i] + (cfg_scale - 1) * delta;
1105+
} else if (has_img_cond) {
1106+
latent_result = positive_data[i] + (img_cfg_scale - 1) * delta;
10521107
}
10531108
}
10541109
} else if (has_img_cond) {
@@ -1096,7 +1151,8 @@ class StableDiffusionGGML {
10961151
}
10971152

10981153
// ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding
1099-
ggml_tensor* get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* moments) {
1154+
ggml_tensor*
1155+
get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* moments) {
11001156
// ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample
11011157
ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
11021158
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent);
@@ -1529,6 +1585,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
15291585
std::string input_id_images_path,
15301586
std::vector<ggml_tensor*> ref_latents,
15311587
ggml_tensor* concat_latent = NULL,
1588+
sd_apg_params_t apg_params = {},
15321589
ggml_tensor* denoise_mask = NULL) {
15331590
if (seed < 0) {
15341591
// Generally, when using the provided command line, the seed is always >0.
@@ -1798,6 +1855,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
17981855
start_merge_step,
17991856
id_cond,
18001857
ref_latents,
1858+
apg_params,
18011859
denoise_mask);
18021860

18031861
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
@@ -1872,7 +1930,7 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
18721930
return init_latent;
18731931
}
18741932

1875-
sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
1933+
sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params, sd_apg_params_t apg_params) {
18761934
int width = sd_img_gen_params->width;
18771935
int height = sd_img_gen_params->height;
18781936
LOG_DEBUG("generate_image %dx%d", width, height);
@@ -2072,6 +2130,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
20722130
sd_img_gen_params->input_id_images_path,
20732131
ref_latents,
20742132
concat_latent,
2133+
apg_params,
20752134
denoise_mask);
20762135

20772136
size_t t2 = ggml_time_ms();

stable-diffusion.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,12 @@ typedef struct {
154154
float scale;
155155
} sd_slg_params_t;
156156

157+
// APG guidance settings (see https://arxiv.org/pdf/2410.02416), passed by value
// to generate_image(). Callers brace-initialize this struct positionally as
// {eta, momentum, norm_treshold}, so the field order is part of the interface.
typedef struct {
158+
// Weight applied to the part of the guidance delta that is parallel to the
// conditioned prediction; the orthogonal remainder is kept unchanged.
// A value of 1.0 skips the projection entirely.
float eta;
159+
// Described in the CLI help as the CFG update direction momentum. The sampler
// only allocates its momentum buffer when this is non-zero.
float momentum;
160+
// When > 0, the guidance delta is scaled by min(1, norm_treshold / delta_norm);
// 0 disables the rescale. The spelling "treshold" is the name callers use.
// NOTE(review): the visible sampler hunk never accumulates the delta norm
// before taking its square root -- confirm the rescale actually takes effect.
float norm_treshold;
161+
} sd_apg_params_t;
162+
157163
typedef struct {
158164
float txt_cfg;
159165
float img_cfg;
@@ -228,7 +234,7 @@ SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
228234

229235
SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
230236
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
231-
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);
237+
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params, sd_apg_params_t apg_params);
232238

233239
SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
234240
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params); // broken

0 commit comments

Comments
 (0)