@@ -40,6 +40,8 @@ const char* sample_method_str[] = {
4040 " dpm++2s_a" ,
4141 " dpm++2m" ,
4242 " dpm++2mv2" ,
43+ " ipndm" ,
44+ " ipndm_v" ,
4345 " lcm" ,
4446};
4547
@@ -48,7 +50,9 @@ const char* schedule_str[] = {
4850 " default" ,
4951 " discrete" ,
5052 " karras" ,
53+ " exponential" ,
5154 " ays" ,
55+ " gits" ,
5256};
5357
5458enum SDMode {
@@ -676,45 +680,45 @@ static void log_server_request(const httplib::Request& req, const httplib::Respo
676680 printf (" request: %s %s (%s)\n " , req.method .c_str (), req.path .c_str (), req.body .c_str ());
677681}
678682
679- void parseJsonPrompt (std::string json_str, SDParams* params) {
683+ void parseJsonPrompt (std::string json_str, SDParams& params) {
680684 using namespace nlohmann ;
681685 json payload = json::parse (json_str);
682686 // if no exception, the request is a json object
683687 // now we try to get the new param values from the payload object
684688 // const char *prompt, const char *negative_prompt, int clip_skip, float cfg_scale, float guidance, int width, int height, sample_method_t sample_method, int sample_steps, int64_t seed, int batch_count, const sd_image_t *control_cond, float control_strength, float style_strength, bool normalize_input, const char *input_id_images_path
685689 try {
686690 std::string prompt = payload[" prompt" ];
687- params-> prompt = prompt;
691+ params. prompt = prompt;
688692 } catch (...) {
689693 }
690694 try {
691695 std::string negative_prompt = payload[" negative_prompt" ];
692- params-> negative_prompt = negative_prompt;
696+ params. negative_prompt = negative_prompt;
693697 } catch (...) {
694698 }
695699 try {
696- int clip_skip = payload[" clip_skip" ];
697- params-> clip_skip = clip_skip;
700+ int clip_skip = payload[" clip_skip" ];
701+ params. clip_skip = clip_skip;
698702 } catch (...) {
699703 }
700704 try {
701- float cfg_scale = payload[" cfg_scale" ];
702- params-> cfg_scale = cfg_scale;
705+ float cfg_scale = payload[" cfg_scale" ];
706+ params. cfg_scale = cfg_scale;
703707 } catch (...) {
704708 }
705709 try {
706- float guidance = payload[" guidance" ];
707- params-> guidance = guidance;
710+ float guidance = payload[" guidance" ];
711+ params. guidance = guidance;
708712 } catch (...) {
709713 }
710714 try {
711- int width = payload[" width" ];
712- params-> width = width;
715+ int width = payload[" width" ];
716+ params. width = width;
713717 } catch (...) {
714718 }
715719 try {
716- int height = payload[" height" ];
717- params-> height = height;
720+ int height = payload[" height" ];
721+ params. height = height;
718722 } catch (...) {
719723 }
720724 try {
@@ -727,25 +731,25 @@ void parseJsonPrompt(std::string json_str, SDParams* params) {
727731 }
728732 }
729733 if (sample_method_found >= 0 ) {
730- params-> sample_method = (sample_method_t )sample_method_found;
734+ params. sample_method = (sample_method_t )sample_method_found;
731735 } else {
732736 sd_log (sd_log_level_t ::SD_LOG_WARN, " Unknown sampling method: %s\n " , sample_method.c_str ());
733737 }
734738 } catch (...) {
735739 }
736740 try {
737- int sample_steps = payload[" sample_steps" ];
738- params-> sample_steps = sample_steps;
741+ int sample_steps = payload[" sample_steps" ];
742+ params. sample_steps = sample_steps;
739743 } catch (...) {
740744 }
741745 try {
742746 int64_t seed = payload[" seed" ];
743- params-> seed = seed;
747+ params. seed = seed;
744748 } catch (...) {
745749 }
746750 try {
747- int batch_count = payload[" batch_count" ];
748- params-> batch_count = batch_count;
751+ int batch_count = payload[" batch_count" ];
752+ params. batch_count = batch_count;
749753 } catch (...) {
750754 }
751755
@@ -759,53 +763,53 @@ void parseJsonPrompt(std::string json_str, SDParams* params) {
759763 }
760764 try {
761765 float control_strength = payload[" control_strength" ];
762- // params-> control_strength = control_strength;
766+ // params. control_strength = control_strength;
763767 // LOG_WARN("control_strength is not supported yet\n");
764768 sd_log (sd_log_level_t ::SD_LOG_WARN, " control_strength is not supported yet\n " , params);
765769 } catch (...) {
766770 }
767771 try {
768772 float style_strength = payload[" style_strength" ];
769- // params-> style_strength = style_strength;
773+ // params. style_strength = style_strength;
770774 // LOG_WARN("style_strength is not supported yet\n");
771775 sd_log (sd_log_level_t ::SD_LOG_WARN, " style_strength is not supported yet\n " , params);
772776 } catch (...) {
773777 }
774778 try {
775- bool normalize_input = payload[" normalize_input" ];
776- params-> normalize_input = normalize_input;
779+ bool normalize_input = payload[" normalize_input" ];
780+ params. normalize_input = normalize_input;
777781 } catch (...) {
778782 }
779783 try {
780784 std::string input_id_images_path = payload[" input_id_images_path" ];
781785 // TODO replace with b64 image maybe?
782- params-> input_id_images_path = input_id_images_path;
786+ params. input_id_images_path = input_id_images_path;
783787 } catch (...) {
784788 }
785789 try {
786790 std::string slg_scale = payload[" slg_scale" ];
787- params-> slg_scale = stof (slg_scale);
791+ params. slg_scale = stof (slg_scale);
788792 } catch (...) {
789793 }
790794 // TODO: more slg settings (layers, start and end)
791795 try {
792796 std::vector<int > skip_layers = payload[" skip_layers" ].get <std::vector<int >>();
793- params-> skip_layers .clear ();
797+ params. skip_layers .clear ();
794798 for (int i = 0 ; i < skip_layers.size (); i++) {
795- params-> skip_layers .push_back (skip_layers[i]);
799+ params. skip_layers .push_back (skip_layers[i]);
796800 }
797801 } catch (...) {
798802 }
799803 try {
800804 // skip_layer_start
801- float skip_layer_start = payload[" skip_layer_start" ].get <float >();
802- params-> skip_layer_start = skip_layer_start;
805+ float skip_layer_start = payload[" skip_layer_start" ].get <float >();
806+ params. skip_layer_start = skip_layer_start;
803807 } catch (...) {
804808 }
805809 try {
806810 // skip_layer_end
807- float skip_layer_end = payload[" skip_layer_end" ].get <float >();
808- params-> skip_layer_end = skip_layer_end;
811+ float skip_layer_end = payload[" skip_layer_end" ].get <float >();
812+ params. skip_layer_end = skip_layer_end;
809813 } catch (...) {
810814 }
811815}
@@ -863,7 +867,7 @@ const float sd_latent_rgb_proj[4][3]{
863867 {-0.2829 , 0.1762 , 0.2721 },
864868 {-0.2120 , -0.2616 , -0.7177 }};
865869
866- void step_callback ( int step, struct ggml_tensor * latents, enum SDVersion version) {
870+ void proj_latents ( struct ggml_tensor * latents, enum SDVersion version, uint8_t * data ) {
867871 const int channel = 3 ;
868872 int width = latents->ne [0 ];
869873 int height = latents->ne [1 ];
@@ -876,7 +880,7 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
876880
877881 if (version == VERSION_SD3_5_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_2B) {
878882 latent_rgb_proj = sd3_latent_rgb_proj;
879- } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
883+ } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE ) {
880884 latent_rgb_proj = flux_latent_rgb_proj;
881885 } else {
882886 // unknown model
@@ -897,7 +901,6 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
897901 // unknown latent space
898902 return ;
899903 }
900- uint8_t * data = (uint8_t *)malloc (width * height * channel * sizeof (uint8_t ));
901904 int data_head = 0 ;
902905 for (int j = 0 ; j < height; j++) {
903906 for (int i = 0 ; i < width; i++) {
@@ -925,6 +928,15 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
925928 data[data_head++] = (uint8_t )(b * 255 .);
926929 }
927930 }
931+ }
932+
933+ void step_callback (int step, struct ggml_tensor * latents, enum SDVersion version) {
934+ const int channel = 3 ;
935+ int width = latents->ne [0 ];
936+ int height = latents->ne [1 ];
937+ int dim = latents->ne [2 ];
938+ uint8_t * data = (uint8_t *)malloc (width * height * channel * sizeof (uint8_t ));
939+ proj_latents (latents, version, data);
928940 stbi_write_png (" latent-preview.png" , width, height, channel, data, 0 );
929941 free (data);
930942}
@@ -982,7 +994,7 @@ int main(int argc, const char* argv[]) {
982994
983995 try {
984996 std::string json_str = req.body ;
985- parseJsonPrompt (json_str, & params);
997+ parseJsonPrompt (json_str, params);
986998 } catch (json::parse_error& e) {
987999 // assume the request is just a prompt
9881000 // LOG_WARN("Failed to parse json: %s\n Assuming it's just a prompt...\n", e.what());
0 commit comments