@@ -75,7 +75,7 @@ struct SDParams {
7575 std::string stacked_id_embeddings_path;
7676 std::string lora_model_dir;
7777
78- sd_type_t wtype = SD_TYPE_COUNT;
78+ sd_type_t wtype = SD_TYPE_COUNT;
7979 std::string output_path = " output.png" ;
8080 std::string input_path;
8181
@@ -707,6 +707,125 @@ void parseJsonPrompt(std::string json_str, SDParams* params) {
707707 }
708708}
709709
// Flux latent -> RGB preview projection.
// Row d holds the {R, G, B} weights that latent channel d contributes to the
// preview pixel (16 latent channels, 3 colour components).
// Source: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L152-L169
const float flux_latent_rgb_proj[16][3] = {
    /*  0 */ {-0.0346,  0.0244,  0.0681},
    /*  1 */ { 0.0034,  0.0210,  0.0687},
    /*  2 */ { 0.0275, -0.0668, -0.0433},
    /*  3 */ {-0.0174,  0.0160,  0.0617},
    /*  4 */ { 0.0859,  0.0721,  0.0329},
    /*  5 */ { 0.0004,  0.0383,  0.0115},
    /*  6 */ { 0.0405,  0.0861,  0.0915},
    /*  7 */ {-0.0236, -0.0185, -0.0259},
    /*  8 */ {-0.0245,  0.0250,  0.1180},
    /*  9 */ { 0.1008,  0.0755, -0.0421},
    /* 10 */ {-0.0515,  0.0201,  0.0011},
    /* 11 */ { 0.0428, -0.0012, -0.0036},
    /* 12 */ { 0.0817,  0.0765,  0.0749},
    /* 13 */ {-0.1264, -0.0522, -0.1103},
    /* 14 */ {-0.0280, -0.0881, -0.0499},
    /* 15 */ {-0.1262, -0.0982, -0.0778},
};
728+
// SD3 / SD3.5 latent -> RGB preview projection.
// Row d holds the {R, G, B} weights that latent channel d contributes to the
// preview pixel (16 latent channels, 3 colour components).
// Source: https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py#L228-L246
const float sd3_latent_rgb_proj[16][3] = {
    /*  0 */ {-0.0645,  0.0177,  0.1052},
    /*  1 */ { 0.0028,  0.0312,  0.0650},
    /*  2 */ { 0.1848,  0.0762,  0.0360},
    /*  3 */ { 0.0944,  0.0360,  0.0889},
    /*  4 */ { 0.0897,  0.0506, -0.0364},
    /*  5 */ {-0.0020,  0.1203,  0.0284},
    /*  6 */ { 0.0855,  0.0118,  0.0283},
    /*  7 */ {-0.0539,  0.0658,  0.1047},
    /*  8 */ {-0.0057,  0.0116,  0.0700},
    /*  9 */ {-0.0412,  0.0281, -0.0039},
    /* 10 */ { 0.1106,  0.1171,  0.1220},
    /* 11 */ {-0.0248,  0.0682, -0.0481},
    /* 12 */ { 0.0815,  0.0846,  0.1207},
    /* 13 */ {-0.0120, -0.0055, -0.0867},
    /* 14 */ {-0.0749, -0.0634, -0.0456},
    /* 15 */ {-0.1418, -0.1457, -0.1259},
};
748+
// SDXL latent -> RGB preview projection.
// Row d holds the {R, G, B} weights that latent channel d contributes to the
// preview pixel (4 latent channels, 3 colour components).
// Source: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
const float sdxl_latent_rgb_proj[4][3] = {
    /* 0 */ { 0.3651,  0.4232,  0.4341},
    /* 1 */ {-0.2533, -0.0042,  0.1068},
    /* 2 */ { 0.1076,  0.1111, -0.0362},
    /* 3 */ {-0.3165, -0.2492, -0.2188},
};
755+
// SD 1.x / 2.x latent -> RGB preview projection.
// Row d holds the {R, G, B} weights that latent channel d contributes to the
// preview pixel (4 latent channels, 3 colour components).
// Source: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
// NOTE(review): this is the same line range cited for the SDXL table although
// the values differ; confirm the link points at the SD 1.x factors.
const float sd_latent_rgb_proj[4][3] = {
    /* 0 */ { 0.3512,  0.2297,  0.3227},
    /* 1 */ { 0.3250,  0.4974,  0.2350},
    /* 2 */ {-0.2829,  0.1762,  0.2721},
    /* 3 */ {-0.2120, -0.2616, -0.7177},
};
762+
763+ void step_callback (int step, struct ggml_tensor * latents, enum SDVersion version) {
764+ const int channel = 3 ;
765+ int width = latents->ne [0 ];
766+ int height = latents->ne [1 ];
767+ int dim = latents->ne [2 ];
768+
769+ const float (*latent_rgb_proj)[channel];
770+
771+ if (dim == 16 ) {
772+ // 16 channels VAE -> Flux or SD3
773+
774+ if (version == VERSION_SD3_5_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_2B) {
775+ latent_rgb_proj = sd3_latent_rgb_proj;
776+ } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
777+ latent_rgb_proj = flux_latent_rgb_proj;
778+ } else {
779+ // unknown model
780+ return ;
781+ }
782+
783+ } else if (dim == 4 ) {
784+ // 4 channels VAE
785+ if (version == VERSION_SDXL) {
786+ latent_rgb_proj = sdxl_latent_rgb_proj;
787+ } else if (version == VERSION_SD1 || version == VERSION_SD2) {
788+ latent_rgb_proj = sd_latent_rgb_proj;
789+ } else {
790+ // unknown model
791+ return ;
792+ }
793+ } else {
794+ // unknown latent space
795+ return ;
796+ }
797+ uint8_t * data = (uint8_t *)malloc (width * height * channel * sizeof (uint8_t ));
798+ int data_head = 0 ;
799+ for (int j = 0 ; j < height; j++) {
800+ for (int i = 0 ; i < width; i++) {
801+ int latent_id = (i * latents->nb [0 ] + j * latents->nb [1 ]);
802+ float r = 0 , g = 0 , b = 0 ;
803+ for (int d = 0 ; d < dim; d++) {
804+ float value = *(float *)((char *)latents->data + latent_id + d * latents->nb [2 ]);
805+ r += value * latent_rgb_proj[d][0 ];
806+ g += value * latent_rgb_proj[d][1 ];
807+ b += value * latent_rgb_proj[d][2 ];
808+ }
809+
810+ // change range
811+ r = r * .5 + .5 ;
812+ g = g * .5 + .5 ;
813+ b = b * .5 + .5 ;
814+
815+ // clamp rgb values to [0,1] range
816+ r = r >= 0 ? r <= 1 ? r : 1 : 0 ;
817+ g = g >= 0 ? g <= 1 ? g : 1 : 0 ;
818+ b = b >= 0 ? b <= 1 ? b : 1 : 0 ;
819+
820+ data[data_head++] = (uint8_t )(r * 255 .);
821+ data[data_head++] = (uint8_t )(g * 255 .);
822+ data[data_head++] = (uint8_t )(b * 255 .);
823+ }
824+ }
825+ stbi_write_png (" latent-preview.png" , width, height, channel, data, 0 );
826+ free (data);
827+ }
828+
710829int main (int argc, const char * argv[]) {
711830 SDParams params;
712831
@@ -797,7 +916,8 @@ int main(int argc, const char* argv[]) {
797916 1 ,
798917 params.style_ratio ,
799918 params.normalize_input ,
800- params.input_id_images_path .c_str ());
919+ params.input_id_images_path .c_str (),
920+ *step_callback);
801921
802922 if (results == NULL ) {
803923 printf (" generate failed\n " );
0 commit comments