@@ -1062,7 +1062,18 @@ class StableDiffusionGGML {
10621062
10631063 float * deltas = vec_denoised;
10641064
1065- // https://arxiv.org/pdf/2410.02416
1065+ // APG: https://arxiv.org/pdf/2410.02416
1066+
1067+ bool log_cfg_norm = false ;
1068+ const char * SD_LOG_CFG_DELTA_NORM = getenv (" SD_LOG_CFG_DELTA_NORM" );
1069+ if (SD_LOG_CFG_DELTA_NORM != nullptr ) {
1070+ std::string sd_log_cfg_norm_str = SD_LOG_CFG_DELTA_NORM;
1071+ if (sd_log_cfg_norm_str == " ON" || sd_log_cfg_norm_str == " TRUE" ) {
1072+ log_cfg_norm = true ;
1073+ } else if (sd_log_cfg_norm_str != " OFF" && sd_log_cfg_norm_str != " FALSE" ) {
1074+ LOG_WARN (" SD_LOG_CFG_DELTA_NORM environment variable has unexpected value. Assuming default (\" OFF\" ). (Expected \" ON\" /\" TRUE\" or\" OFF\" /\" FALSE\" , got \" %s\" )" , SD_LOG_CFG_DELTA_NORM);
1075+ }
1076+ }
10661077 float apg_scale_factor = 1 .;
10671078 float diff_norm = 0 ;
10681079 float cond_norm_sq = 0 ;
@@ -1085,8 +1096,22 @@ class StableDiffusionGGML {
10851096 // classic CFG (img_cfg_scale == cfg_scale != 1)
10861097 delta = positive_data[i] - negative_data[i];
10871098 }
1099+ if (apg_params.momentum != 0 ) {
1100+ delta += apg_params.momentum * apg_momentum_buffer[i];
1101+ apg_momentum_buffer[i] = delta;
1102+ }
1103+ if (apg_params.norm_treshold > 0 || log_cfg_norm) {
1104+ diff_norm += delta * delta;
1105+ }
1106+ if (apg_params.eta != 1 .0f ) {
1107+ cond_norm_sq += positive_data[i] * positive_data[i];
1108+ dot += positive_data[i] * delta;
1109+ }
10881110 deltas[i] = delta;
10891111 }
1112+ if (log_cfg_norm) {
1113+ LOG_INFO (" CFG Delta norm: %.2f" , sqrtf (diff_norm));
1114+ }
10901115 if (apg_params.norm_treshold > 0 ) {
10911116 diff_norm = sqrtf (diff_norm);
10921117 if (apg_params.norm_treshold_smoothing <= 0 ) {
0 commit comments