@@ -712,7 +712,8 @@ namespace Flux {
712712 struct ggml_tensor * timesteps,
713713 struct ggml_tensor * y,
714714 struct ggml_tensor * guidance,
715- struct ggml_tensor * pe) {
715+ struct ggml_tensor * pe,
716+ std::vector<int > skip_layers = std::vector<int >()) {
716717 auto img_in = std::dynamic_pointer_cast<Linear>(blocks[" img_in" ]);
717718 auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks[" time_in" ]);
718719 auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks[" vector_in" ]);
@@ -734,6 +735,10 @@ namespace Flux {
734735 txt = txt_in->forward (ctx, txt);
735736
736737 for (int i = 0 ; i < params.depth ; i++) {
738+ if (skip_layers.size () > 0 && std::find (skip_layers.begin (), skip_layers.end (), i) != skip_layers.end ()) {
739+ continue ;
740+ }
741+
737742 auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks[" double_blocks." + std::to_string (i)]);
738743
739744 auto img_txt = block->forward (ctx, img, txt, vec, pe);
@@ -743,6 +748,9 @@ namespace Flux {
743748
744749 auto txt_img = ggml_concat (ctx, txt, img, 1 ); // [N, n_txt_token + n_img_token, hidden_size]
745750 for (int i = 0 ; i < params.depth_single_blocks ; i++) {
751+ if (skip_layers.size () > 0 && std::find (skip_layers.begin (), skip_layers.end (), i + params.depth ) != skip_layers.end ()) {
752+ continue ;
753+ }
746754 auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks[" single_blocks." + std::to_string (i)]);
747755
748756 txt_img = block->forward (ctx, txt_img, vec, pe);
@@ -770,7 +778,8 @@ namespace Flux {
770778 struct ggml_tensor * context,
771779 struct ggml_tensor * y,
772780 struct ggml_tensor * guidance,
773- struct ggml_tensor * pe) {
781+ struct ggml_tensor * pe,
782+ std::vector<int > skip_layers = std::vector<int >()) {
774783 // Forward pass of DiT.
775784 // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
776785 // timestep: (N,) tensor of diffusion timesteps
@@ -792,7 +801,7 @@ namespace Flux {
792801 // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
793802 auto img = patchify (ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
794803
795- auto out = forward_orig (ctx, img, context, timestep, y, guidance, pe); // [N, h*w, C * patch_size * patch_size]
804+ auto out = forward_orig (ctx, img, context, timestep, y, guidance, pe, skip_layers ); // [N, h*w, C * patch_size * patch_size]
796805
797806 // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
798807 out = unpatchify (ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w]
@@ -833,7 +842,8 @@ namespace Flux {
833842 struct ggml_tensor * timesteps,
834843 struct ggml_tensor * context,
835844 struct ggml_tensor * y,
836- struct ggml_tensor * guidance) {
845+ struct ggml_tensor * guidance,
846+ std::vector<int > skip_layers = std::vector<int >()) {
837847 GGML_ASSERT (x->ne [3 ] == 1 );
838848 struct ggml_cgraph * gf = ggml_new_graph_custom (compute_ctx, FLUX_GRAPH_SIZE, false );
839849
@@ -860,7 +870,8 @@ namespace Flux {
860870 context,
861871 y,
862872 guidance,
863- pe);
873+ pe,
874+ skip_layers);
864875
865876 ggml_build_forward_expand (gf, out);
866877
@@ -874,14 +885,15 @@ namespace Flux {
874885 struct ggml_tensor * y,
875886 struct ggml_tensor * guidance,
876887 struct ggml_tensor ** output = NULL ,
877- struct ggml_context * output_ctx = NULL ) {
888+ struct ggml_context * output_ctx = NULL ,
889+ std::vector<int > skip_layers = std::vector<int >()) {
878890 // x: [N, in_channels, h, w]
879891 // timesteps: [N, ]
880892 // context: [N, max_position, hidden_size]
881893 // y: [N, adm_in_channels] or [1, adm_in_channels]
882894 // guidance: [N, ]
883895 auto get_graph = [&]() -> struct ggml_cgraph * {
884- return build_graph(x, timesteps, context, y, guidance);
896+ return build_graph(x, timesteps, context, y, guidance, skip_layers );
885897 };
886898
887899 GGMLRunner::compute (get_graph, n_threads, false , output, output_ctx);
0 commit comments