@@ -40,6 +40,8 @@ const char* sample_method_str[] = {
4040 " dpm++2s_a" ,
4141 " dpm++2m" ,
4242 " dpm++2mv2" ,
43+ " ipndm" ,
44+ " ipndm_v" ,
4345 " lcm" ,
4446};
4547
@@ -48,7 +50,9 @@ const char* schedule_str[] = {
4850 " default" ,
4951 " discrete" ,
5052 " karras" ,
53+ " exponential" ,
5154 " ays" ,
55+ " gits" ,
5256};
5357
5458enum SDMode {
@@ -676,45 +680,45 @@ static void log_server_request(const httplib::Request& req, const httplib::Respo
676680 printf (" request: %s %s (%s)\n " , req.method .c_str (), req.path .c_str (), req.body .c_str ());
677681}
678682
679- void parseJsonPrompt (std::string json_str, SDParams* params) {
683+ void parseJsonPrompt (std::string json_str, SDParams& params) {
680684 using namespace nlohmann ;
681685 json payload = json::parse (json_str);
682686 // if no exception, the request is a json object
683687 // now we try to get the new param values from the payload object
684688 // const char *prompt, const char *negative_prompt, int clip_skip, float cfg_scale, float guidance, int width, int height, sample_method_t sample_method, int sample_steps, int64_t seed, int batch_count, const sd_image_t *control_cond, float control_strength, float style_strength, bool normalize_input, const char *input_id_images_path
685689 try {
686690 std::string prompt = payload[" prompt" ];
687- params-> prompt = prompt;
691+ params. prompt = prompt;
688692 } catch (...) {
689693 }
690694 try {
691695 std::string negative_prompt = payload[" negative_prompt" ];
692- params-> negative_prompt = negative_prompt;
696+ params. negative_prompt = negative_prompt;
693697 } catch (...) {
694698 }
695699 try {
696- int clip_skip = payload[" clip_skip" ];
697- params-> clip_skip = clip_skip;
700+ int clip_skip = payload[" clip_skip" ];
701+ params. clip_skip = clip_skip;
698702 } catch (...) {
699703 }
700704 try {
701- float cfg_scale = payload[" cfg_scale" ];
702- params-> cfg_scale = cfg_scale;
705+ float cfg_scale = payload[" cfg_scale" ];
706+ params. cfg_scale = cfg_scale;
703707 } catch (...) {
704708 }
705709 try {
706- float guidance = payload[" guidance" ];
707- params-> guidance = guidance;
710+ float guidance = payload[" guidance" ];
711+ params. guidance = guidance;
708712 } catch (...) {
709713 }
710714 try {
711- int width = payload[" width" ];
712- params-> width = width;
715+ int width = payload[" width" ];
716+ params. width = width;
713717 } catch (...) {
714718 }
715719 try {
716- int height = payload[" height" ];
717- params-> height = height;
720+ int height = payload[" height" ];
721+ params. height = height;
718722 } catch (...) {
719723 }
720724 try {
@@ -727,25 +731,25 @@ void parseJsonPrompt(std::string json_str, SDParams* params) {
727731 }
728732 }
729733 if (sample_method_found >= 0 ) {
730- params-> sample_method = (sample_method_t )sample_method_found;
734+ params. sample_method = (sample_method_t )sample_method_found;
731735 } else {
732736 sd_log (sd_log_level_t ::SD_LOG_WARN, " Unknown sampling method: %s\n " , sample_method.c_str ());
733737 }
734738 } catch (...) {
735739 }
736740 try {
737- int sample_steps = payload[" sample_steps" ];
738- params-> sample_steps = sample_steps;
741+ int sample_steps = payload[" sample_steps" ];
742+ params. sample_steps = sample_steps;
739743 } catch (...) {
740744 }
741745 try {
742746 int64_t seed = payload[" seed" ];
743- params-> seed = seed;
747+ params. seed = seed;
744748 } catch (...) {
745749 }
746750 try {
747- int batch_count = payload[" batch_count" ];
748- params-> batch_count = batch_count;
751+ int batch_count = payload[" batch_count" ];
752+ params. batch_count = batch_count;
749753 } catch (...) {
750754 }
751755
@@ -759,57 +763,135 @@ void parseJsonPrompt(std::string json_str, SDParams* params) {
759763 }
760764 try {
761765 float control_strength = payload[" control_strength" ];
762- // params-> control_strength = control_strength;
766+ // params. control_strength = control_strength;
763767 // LOG_WARN("control_strength is not supported yet\n");
764768 sd_log (sd_log_level_t ::SD_LOG_WARN, " control_strength is not supported yet\n " , params);
765769 } catch (...) {
766770 }
767771 try {
768772 float style_strength = payload[" style_strength" ];
769- // params-> style_strength = style_strength;
773+ // params. style_strength = style_strength;
770774 // LOG_WARN("style_strength is not supported yet\n");
771775 sd_log (sd_log_level_t ::SD_LOG_WARN, " style_strength is not supported yet\n " , params);
772776 } catch (...) {
773777 }
774778 try {
775- bool normalize_input = payload[" normalize_input" ];
776- params-> normalize_input = normalize_input;
779+ bool normalize_input = payload[" normalize_input" ];
780+ params. normalize_input = normalize_input;
777781 } catch (...) {
778782 }
779783 try {
780784 std::string input_id_images_path = payload[" input_id_images_path" ];
781785 // TODO replace with b64 image maybe?
782- params-> input_id_images_path = input_id_images_path;
786+ params. input_id_images_path = input_id_images_path;
783787 } catch (...) {
784788 }
785789 try {
786790 std::string slg_scale = payload[" slg_scale" ];
787- params-> slg_scale = stof (slg_scale);
791+ params. slg_scale = stof (slg_scale);
788792 } catch (...) {
789793 }
790794 // TODO: more slg settings (layers, start and end)
791795 try {
792796 std::vector<int > skip_layers = payload[" skip_layers" ].get <std::vector<int >>();
793- params-> skip_layers .clear ();
797+ params. skip_layers .clear ();
794798 for (int i = 0 ; i < skip_layers.size (); i++) {
795- params-> skip_layers .push_back (skip_layers[i]);
799+ params. skip_layers .push_back (skip_layers[i]);
796800 }
797801 } catch (...) {
798802 }
799803 try {
800804 // skip_layer_start
801- float skip_layer_start = payload[" skip_layer_start" ].get <float >();
802- params-> skip_layer_start = skip_layer_start;
805+ float skip_layer_start = payload[" skip_layer_start" ].get <float >();
806+ params. skip_layer_start = skip_layer_start;
803807 } catch (...) {
804808 }
805809 try {
806810 // skip_layer_end
807- float skip_layer_end = payload[" skip_layer_end" ].get <float >();
808- params-> skip_layer_end = skip_layer_end;
811+ float skip_layer_end = payload[" skip_layer_end" ].get <float >();
812+ params. skip_layer_end = skip_layer_end;
809813 } catch (...) {
810814 }
811815}
812816
817+ struct SDAPIParams {
818+ std::string prompt = " " ;
819+ std::string negative_prompt = " " ;
820+ // std::vector<std::string> styles = {};
821+ int seed = -1 ;
822+ // int subseed = -1;
823+ // float subseed_strength = 0.;
824+ // int seed_resize_from_h = -1;
825+ // int seed_resize_from_w = -1;
826+ std::string sampler_name = " " ;
827+ std::string scheduler = " " ;
828+ // int batch_size = 1; //batch processing
829+ int n_iter = 1 ; // batch_count
830+ int steps = 50 ;
831+ float cfg_scale = 7 ;
832+ int width = 512 ;
833+ int height = 512 ;
834+ // bool restore_faces = false;
835+ // bool tiling = false;
836+ // bool do_not_save_samples = false;
837+ // bool do_not_save_grid = false;
838+ // float eta = 0; //for ddim
839+ float denoising_strength = 0.75 ;
840+ // float s_min_uncond = 0;
841+ // float s_churn = 0;
842+ // float s_tmax = 0;
843+ // float s_tmin = 0;
844+ // float s_noise = 0;
845+ // nlohmann::json override_settings = {};
846+ // bool override_settings_restore_afterwards = true;
847+ // std::string refiner_checkpoint = "";
848+ // int refiner_switch_at = 0;
849+ // bool disable_extra_networks = false;
850+ // std::string firstpass_image = ""; //for highres_fix upscaling
851+ // nlohmann::json comments = {};
852+
853+ // // Highres_fix stuff
854+ // bool enable_hr = false;
855+ // int firstphase_width = 0;
856+ // int firstphase_height = 0;
857+ // int hr_scale = 2;
858+ // std::string hr_upscaler = "";
859+ // int hr_second_pass_steps = 0;
860+ // int hr_resize_x = 0;
861+ // int hr_resize_y = 0;
862+ // std::string hr_checkpoint_name = "";
863+ // std::string hr_sampler_name = "";
864+ // std::string hr_scheduler = "";
865+ // std::string hr_prompt = "";
866+ // std::string hr_negative_prompt = "";
867+
868+ // // img2img stuff
869+ std::vector<std::string> init_images = {};
870+ int resize_mode = 0 ;
871+ float image_cfg_scale = 0 ;
872+ std::string mask = " " ;
873+ int mask_blur_x = 4 ;
874+ int mask_blur_y = 4 ;
875+ int mask_blur = 0 ;
876+ bool mask_round = true ;
877+ int inpainting_fill = 0 ;
878+ bool inpaint_full_res = true ;
879+ int inpaint_full_res_padding = 0 ;
880+ int inpainting_mask_invert = 0 ;
881+ float initial_noise_multiplier = 0 ;
882+ std::string latent_mask = " " ;
883+
884+ // std::string force_task_id = "";
885+ std::string sampler_index = " Euler" ;
886+ // bool include_init_images = false;
887+ // std::string script_name = "";
888+ // nlohmann::json script_args = {};
889+ bool send_images = true ;
890+ bool save_images = false ;
891+ // nlohmann::json alwayson_scripts = {};
892+ // std::string infotext = "";
893+ };
894+
813895// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L152-L169
814896const float flux_latent_rgb_proj[16 ][3 ] = {
815897 {-0.0346 , 0.0244 , 0.0681 },
@@ -863,7 +945,7 @@ const float sd_latent_rgb_proj[4][3]{
863945 {-0.2829 , 0.1762 , 0.2721 },
864946 {-0.2120 , -0.2616 , -0.7177 }};
865947
866- void step_callback ( int step, struct ggml_tensor * latents, enum SDVersion version) {
948+ void proj_latents ( struct ggml_tensor * latents, enum SDVersion version, uint8_t * data ) {
867949 const int channel = 3 ;
868950 int width = latents->ne [0 ];
869951 int height = latents->ne [1 ];
@@ -876,7 +958,7 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
876958
877959 if (version == VERSION_SD3_5_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_2B) {
878960 latent_rgb_proj = sd3_latent_rgb_proj;
879- } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
961+ } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE ) {
880962 latent_rgb_proj = flux_latent_rgb_proj;
881963 } else {
882964 // unknown model
@@ -897,7 +979,6 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
897979 // unknown latent space
898980 return ;
899981 }
900- uint8_t * data = (uint8_t *)malloc (width * height * channel * sizeof (uint8_t ));
901982 int data_head = 0 ;
902983 for (int j = 0 ; j < height; j++) {
903984 for (int i = 0 ; i < width; i++) {
@@ -925,6 +1006,15 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
9251006 data[data_head++] = (uint8_t )(b * 255 .);
9261007 }
9271008 }
1009+ }
1010+
1011+ void step_callback (int step, struct ggml_tensor * latents, enum SDVersion version) {
1012+ const int channel = 3 ;
1013+ int width = latents->ne [0 ];
1014+ int height = latents->ne [1 ];
1015+ int dim = latents->ne [2 ];
1016+ uint8_t * data = (uint8_t *)malloc (width * height * channel * sizeof (uint8_t ));
1017+ proj_latents (latents, version, data);
9281018 stbi_write_png (" latent-preview.png" , width, height, channel, data, 0 );
9291019 free (data);
9301020}
@@ -982,7 +1072,7 @@ int main(int argc, const char* argv[]) {
9821072
9831073 try {
9841074 std::string json_str = req.body ;
985- parseJsonPrompt (json_str, & params);
1075+ parseJsonPrompt (json_str, params);
9861076 } catch (json::parse_error& e) {
9871077 // assume the request is just a prompt
9881078 // LOG_WARN("Failed to parse json: %s\n Assuming it's just a prompt...\n", e.what());
0 commit comments