@@ -672,11 +672,11 @@ namespace Flux {
672672 }
673673
674674 // Generate IDs for image patches only; text IDs are combined with them in gen_pe
675- std::vector<std::vector<float >> gen_ids (int h, int w, int patch_size, int bs, int context_len ) {
675+ std::vector<std::vector<float >> gen_ids (int h, int w, int patch_size, int index = 0 ) {
676676 int h_len = (h + (patch_size / 2 )) / patch_size;
677677 int w_len = (w + (patch_size / 2 )) / patch_size;
678678
679- std::vector<std::vector<float >> img_ids (h_len * w_len, std::vector<float >(3 , 0.0 ));
679+ std::vector<std::vector<float >> img_ids (h_len * w_len, std::vector<float >(3 , ( float )index ));
680680
681681 std::vector<float > row_ids = linspace (0 , h_len - 1 , h_len);
682682 std::vector<float > col_ids = linspace (0 , w_len - 1 , w_len);
@@ -688,10 +688,22 @@ namespace Flux {
688688 }
689689 }
690690
691- std::vector<std::vector<float >> img_ids_repeated (bs * img_ids.size (), std::vector<float >(3 ));
692- for (int i = 0 ; i < bs; ++i) {
693- for (int j = 0 ; j < img_ids.size (); ++j) {
694- img_ids_repeated[i * img_ids.size () + j] = img_ids[j];
691+ return img_ids;
692+ }
693+
694+ // Generate positional embeddings
695+ std::vector<float > gen_pe (std::vector<struct ggml_tensor *> imgs, struct ggml_tensor * context, int patch_size, int theta, const std::vector<int >& axes_dim) {
696+ int context_len = context->ne [1 ];
697+ int bs = imgs[0 ]->ne [3 ];
698+
699+ std::vector<std::vector<float >> img_ids;
700+ for (int i = 0 ; i < imgs.size (); i++) {
701+ auto x = imgs[i];
702+ if (x) {
703+ int h = x->ne [1 ];
704+ int w = x->ne [0 ];
705+ std::vector<std::vector<float >> img_ids_i = gen_ids (h, w, patch_size, i);
706+ img_ids.insert (img_ids.end (), img_ids_i.begin (), img_ids_i.end ());
695707 }
696708 }
697709
@@ -702,17 +714,10 @@ namespace Flux {
702714 ids[i * (context_len + img_ids.size ()) + j] = txt_ids[j];
703715 }
704716 for (int j = 0 ; j < img_ids.size (); ++j) {
705- ids[i * (context_len + img_ids.size ()) + context_len + j] = img_ids_repeated[i * img_ids. size () + j];
717+ ids[i * (context_len + img_ids.size ()) + context_len + j] = img_ids[ j];
706718 }
707719 }
708720
709- return ids;
710- }
711-
712-
713- // Generate positional embeddings
714- std::vector<float > gen_pe (int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector<int >& axes_dim) {
715- std::vector<std::vector<float >> ids = gen_ids (h, w, patch_size, bs, context_len);
716721 std::vector<std::vector<float >> trans_ids = transpose (ids);
717722 size_t pos_len = ids.size ();
718723 int num_axes = axes_dim.size ();
@@ -932,6 +937,7 @@ namespace Flux {
932937 struct ggml_tensor * y,
933938 struct ggml_tensor * guidance,
934939 struct ggml_tensor * pe,
940+ bool kontext_concat = false ,
935941 struct ggml_tensor * arange = NULL ,
936942 std::vector<int > skip_layers = std::vector<int >(),
937943 SDVersion version = VERSION_FLUX) {
@@ -956,8 +962,8 @@ namespace Flux {
956962 x = ggml_pad (ctx, x, pad_w, pad_h, 0 , 0 ); // [N, C, H + pad_h, W + pad_w]
957963
958964 // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
959- auto img = patchify (ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
960-
965+ auto img = patchify (ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
966+ int64_t patchified_img_size = img-> ne [ 1 ];
961967 if (version == VERSION_FLUX_FILL) {
962968 GGML_ASSERT (c_concat != NULL );
963969 ggml_tensor* masked = ggml_view_4d (ctx, c_concat, c_concat->ne [0 ], c_concat->ne [1 ], C, 1 , c_concat->nb [1 ], c_concat->nb [2 ], c_concat->nb [3 ], 0 );
@@ -993,10 +999,16 @@ namespace Flux {
993999 control = patchify (ctx, control, patch_size);
9941000
9951001 img = ggml_concat (ctx, img, control, 0 );
1002+ } else if (kontext_concat && c_concat != NULL ) {
1003+ ggml_tensor* kontext = ggml_pad (ctx, c_concat, pad_w, pad_h, 0 , 0 );
1004+ kontext = patchify (ctx, kontext, patch_size);
1005+ img = ggml_concat (ctx, img, kontext, 1 );
9961006 }
9971007
9981008 auto out = forward_orig (ctx, img, context, timestep, y, guidance, pe, arange, skip_layers); // [N, h*w, C * patch_size * patch_size]
9991009
1010+ out = ggml_cont (ctx, ggml_view_2d (ctx, out, out->ne [0 ], patchified_img_size, out->nb [1 ], 0 ));
1011+
10001012 // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
10011013 out = unpatchify (ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w]
10021014
@@ -1010,7 +1022,7 @@ namespace Flux {
10101022 public:
10111023 FluxParams flux_params;
10121024 Flux flux;
1013- std::vector<float > pe_vec, range; // for cache
1025+ std::vector<float > pe_vec, concat_pe_vec, range; // for cache
10141026 SDVersion version;
10151027
10161028 FluxRunner (ggml_backend_t backend,
@@ -1085,6 +1097,7 @@ namespace Flux {
10851097 struct ggml_tensor * c_concat,
10861098 struct ggml_tensor * y,
10871099 struct ggml_tensor * guidance,
1100+ bool kontext_concat = false ,
10881101 std::vector<int > skip_layers = std::vector<int >()) {
10891102 GGML_ASSERT (x->ne [3 ] == 1 );
10901103 struct ggml_cgraph * gf = ggml_new_graph_custom (compute_ctx, FLUX_GRAPH_SIZE, false );
@@ -1113,7 +1126,6 @@ namespace Flux {
11131126 guidance = ggml_set_f32 (guidance, 0 );
11141127 }
11151128
1116-
11171129 const char * SD_CHROMA_USE_DIT_MASK = getenv (" SD_CHROMA_USE_DIT_MASK" );
11181130 if (SD_CHROMA_USE_DIT_MASK != nullptr ) {
11191131 std::string sd_chroma_use_DiT_mask_str = SD_CHROMA_USE_DIT_MASK;
@@ -1137,7 +1149,12 @@ namespace Flux {
11371149 guidance = to_backend (guidance);
11381150 }
11391151
1140- pe_vec = flux.gen_pe (x->ne [1 ], x->ne [0 ], 2 , x->ne [3 ], context->ne [1 ], flux_params.theta , flux_params.axes_dim );
1152+ std::vector<struct ggml_tensor *> imgs{x};
1153+ if (kontext_concat && c_concat != NULL ) {
1154+ imgs.push_back (c_concat);
1155+ }
1156+
1157+ pe_vec = flux.gen_pe (imgs, context, 2 , flux_params.theta , flux_params.axes_dim );
11411158 int pos_len = pe_vec.size () / flux_params.axes_dim_sum / 2 ;
11421159 // LOG_DEBUG("pos_len %d", pos_len);
11431160 auto pe = ggml_new_tensor_4d (compute_ctx, GGML_TYPE_F32, 2 , 2 , flux_params.axes_dim_sum / 2 , pos_len);
@@ -1146,6 +1163,17 @@ namespace Flux {
11461163 // pe->data = NULL;
11471164 set_backend_tensor_data (pe, pe_vec.data ());
11481165
1166+ // if (kontext_concat && c_concat != NULL) {
1167+ // // TODO: offsets
1168+ // concat_pe_vec = flux.gen_pe(x, context, 2, flux_params.theta, flux_params.axes_dim);
1169+ // int pos_len = concat_pe_vec.size() / flux_params.axes_dim_sum / 2;
1170+
1171+ // auto concat_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
1172+
1173+ // set_backend_tensor_data(concat_pe, concat_pe_vec.data());
1174+ // pe = ggml_concat(compute_ctx, pe, concat_pe, 3);
1175+ // }
1176+
11491177 struct ggml_tensor * out = flux.forward (compute_ctx,
11501178 x,
11511179 timesteps,
@@ -1154,6 +1182,7 @@ namespace Flux {
11541182 y,
11551183 guidance,
11561184 pe,
1185+ kontext_concat,
11571186 precompute_arange,
11581187 skip_layers,
11591188 version);
@@ -1170,6 +1199,7 @@ namespace Flux {
11701199 struct ggml_tensor * c_concat,
11711200 struct ggml_tensor * y,
11721201 struct ggml_tensor * guidance,
1202+ bool kontext_concat = false ,
11731203 struct ggml_tensor ** output = NULL ,
11741204 struct ggml_context * output_ctx = NULL ,
11751205 std::vector<int > skip_layers = std::vector<int >()) {
@@ -1179,7 +1209,7 @@ namespace Flux {
11791209 // y: [N, adm_in_channels] or [1, adm_in_channels]
11801210 // guidance: [N, ]
11811211 auto get_graph = [&]() -> struct ggml_cgraph * {
1182- return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
1212+ return build_graph(x, timesteps, context, c_concat, y, guidance, kontext_concat, skip_layers);
11831213 };
11841214
11851215 return GGMLRunner::compute (get_graph, n_threads, false , output, output_ctx);
@@ -1219,10 +1249,9 @@ namespace Flux {
12191249 struct ggml_tensor * out = NULL ;
12201250
12211251 int t0 = ggml_time_ms ();
1222- compute (8 , x, timesteps, context, NULL , y, guidance, &out, work_ctx);
1252+ compute (8 , x, timesteps, context, NULL , y, guidance, false , &out, work_ctx);
12231253 int t1 = ggml_time_ms ();
12241254
1225- print_ggml_tensor (out);
12261255 LOG_DEBUG (" flux test done in %dms" , t1 - t0);
12271256 }
12281257 }
0 commit comments