@@ -765,6 +765,125 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
765765 fflush (out_stream);
766766}
767767
768+ // https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L152-L169
769+ const float flux_latent_rgb_proj[16 ][3 ] = {
770+ {-0.0346 , 0.0244 , 0.0681 },
771+ {0.0034 , 0.0210 , 0.0687 },
772+ {0.0275 , -0.0668 , -0.0433 },
773+ {-0.0174 , 0.0160 , 0.0617 },
774+ {0.0859 , 0.0721 , 0.0329 },
775+ {0.0004 , 0.0383 , 0.0115 },
776+ {0.0405 , 0.0861 , 0.0915 },
777+ {-0.0236 , -0.0185 , -0.0259 },
778+ {-0.0245 , 0.0250 , 0.1180 },
779+ {0.1008 , 0.0755 , -0.0421 },
780+ {-0.0515 , 0.0201 , 0.0011 },
781+ {0.0428 , -0.0012 , -0.0036 },
782+ {0.0817 , 0.0765 , 0.0749 },
783+ {-0.1264 , -0.0522 , -0.1103 },
784+ {-0.0280 , -0.0881 , -0.0499 },
785+ {-0.1262 , -0.0982 , -0.0778 }};
786+
787+ // https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py#L228-L246
788+ const float sd3_latent_rgb_proj[16 ][3 ] = {
789+ {-0.0645 , 0.0177 , 0.1052 },
790+ {0.0028 , 0.0312 , 0.0650 },
791+ {0.1848 , 0.0762 , 0.0360 },
792+ {0.0944 , 0.0360 , 0.0889 },
793+ {0.0897 , 0.0506 , -0.0364 },
794+ {-0.0020 , 0.1203 , 0.0284 },
795+ {0.0855 , 0.0118 , 0.0283 },
796+ {-0.0539 , 0.0658 , 0.1047 },
797+ {-0.0057 , 0.0116 , 0.0700 },
798+ {-0.0412 , 0.0281 , -0.0039 },
799+ {0.1106 , 0.1171 , 0.1220 },
800+ {-0.0248 , 0.0682 , -0.0481 },
801+ {0.0815 , 0.0846 , 0.1207 },
802+ {-0.0120 , -0.0055 , -0.0867 },
803+ {-0.0749 , -0.0634 , -0.0456 },
804+ {-0.1418 , -0.1457 , -0.1259 },
805+ };
806+
807+ // https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
808+ const float sdxl_latent_rgb_proj[4 ][3 ] = {
809+ {0.3651 , 0.4232 , 0.4341 },
810+ {-0.2533 , -0.0042 , 0.1068 },
811+ {0.1076 , 0.1111 , -0.0362 },
812+ {-0.3165 , -0.2492 , -0.2188 }};
813+
814+ // https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
815+ const float sd_latent_rgb_proj[4 ][3 ]{
816+ {0.3512 , 0.2297 , 0.3227 },
817+ {0.3250 , 0.4974 , 0.2350 },
818+ {-0.2829 , 0.1762 , 0.2721 },
819+ {-0.2120 , -0.2616 , -0.7177 }};
820+
821+ void step_callback (int step, struct ggml_tensor * latents, enum SDVersion version) {
822+ const int channel = 3 ;
823+ int width = latents->ne [0 ];
824+ int height = latents->ne [1 ];
825+ int dim = latents->ne [2 ];
826+
827+ const float (*latent_rgb_proj)[channel];
828+
829+ if (dim == 16 ) {
830+ // 16 channels VAE -> Flux or SD3
831+
832+ if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B /* || version == VERSION_SD3_5_2B*/ ) {
833+ latent_rgb_proj = sd3_latent_rgb_proj;
834+ } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
835+ latent_rgb_proj = flux_latent_rgb_proj;
836+ } else {
837+ // unknown model
838+ return ;
839+ }
840+
841+ } else if (dim == 4 ) {
842+ // 4 channels VAE
843+ if (version == VERSION_SDXL) {
844+ latent_rgb_proj = sdxl_latent_rgb_proj;
845+ } else if (version == VERSION_SD1 || version == VERSION_SD2) {
846+ latent_rgb_proj = sd_latent_rgb_proj;
847+ } else {
848+ // unknown model
849+ return ;
850+ }
851+ } else {
852+ // unknown latent space
853+ return ;
854+ }
855+ uint8_t * data = (uint8_t *)malloc (width * height * channel * sizeof (uint8_t ));
856+ int data_head = 0 ;
857+ for (int j = 0 ; j < height; j++) {
858+ for (int i = 0 ; i < width; i++) {
859+ int latent_id = (i * latents->nb [0 ] + j * latents->nb [1 ]);
860+ float r = 0 , g = 0 , b = 0 ;
861+ for (int d = 0 ; d < dim; d++) {
862+ float value = *(float *)((char *)latents->data + latent_id + d * latents->nb [2 ]);
863+ r += value * latent_rgb_proj[d][0 ];
864+ g += value * latent_rgb_proj[d][1 ];
865+ b += value * latent_rgb_proj[d][2 ];
866+ }
867+
868+ // change range
869+ r = r * .5 + .5 ;
870+ g = g * .5 + .5 ;
871+ b = b * .5 + .5 ;
872+
873+ // clamp rgb values to [0,1] range
874+ r = r >= 0 ? r <= 1 ? r : 1 : 0 ;
875+ g = g >= 0 ? g <= 1 ? g : 1 : 0 ;
876+ b = b >= 0 ? b <= 1 ? b : 1 : 0 ;
877+
878+ data[data_head++] = (uint8_t )(r * 255 .);
879+ data[data_head++] = (uint8_t )(g * 255 .);
880+ data[data_head++] = (uint8_t )(b * 255 .);
881+ }
882+ }
883+ stbi_write_png (" latent-preview.png" , width, height, channel, data, 0 );
884+ free (data);
885+ }
886+
768887int main (int argc, const char * argv[]) {
769888 SDParams params;
770889
@@ -930,7 +1049,8 @@ int main(int argc, const char* argv[]) {
9301049 params.skip_layers .size (),
9311050 params.slg_scale ,
9321051 params.skip_layer_start ,
933- params.skip_layer_end );
1052+ params.skip_layer_end ,
1053+ step_callback);
9341054 } else {
9351055 sd_image_t input_image = {(uint32_t )params.width ,
9361056 (uint32_t )params.height ,
0 commit comments