
Commit 6bceb65

Initial Kontext support (WIP)
1 parent ac708e8 commit 6bceb65

3 files changed (+89, -47 lines)

diffusion_model.hpp

Lines changed: 5 additions & 1 deletion
@@ -16,6 +16,7 @@ struct DiffusionModel {
                          int num_video_frames = -1,
                          std::vector<struct ggml_tensor*> controls = {},
                          float control_strength = 0.f,
+                         bool kontext_concat = false,
                          struct ggml_tensor** output = NULL,
                          struct ggml_context* output_ctx = NULL,
                          std::vector<int> skip_layers = std::vector<int>()) = 0;
@@ -71,6 +72,7 @@ struct UNetModel : public DiffusionModel {
                  int num_video_frames = -1,
                  std::vector<struct ggml_tensor*> controls = {},
                  float control_strength = 0.f,
+                 bool kontext_concat = false,
                  struct ggml_tensor** output = NULL,
                  struct ggml_context* output_ctx = NULL,
                  std::vector<int> skip_layers = std::vector<int>()) {
@@ -121,6 +123,7 @@ struct MMDiTModel : public DiffusionModel {
                  int num_video_frames = -1,
                  std::vector<struct ggml_tensor*> controls = {},
                  float control_strength = 0.f,
+                 bool kontext_concat = false,
                  struct ggml_tensor** output = NULL,
                  struct ggml_context* output_ctx = NULL,
                  std::vector<int> skip_layers = std::vector<int>()) {
@@ -172,10 +175,11 @@ struct FluxModel : public DiffusionModel {
                  int num_video_frames = -1,
                  std::vector<struct ggml_tensor*> controls = {},
                  float control_strength = 0.f,
+                 bool kontext_concat = false,
                  struct ggml_tensor** output = NULL,
                  struct ggml_context* output_ctx = NULL,
                  std::vector<int> skip_layers = std::vector<int>()) {
-        return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers);
+        return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, kontext_concat, output, output_ctx, skip_layers);
     }
 };
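
The new flag is threaded through every `compute` override, sitting between `control_strength` and `output`, so positional call sites gain one argument. A minimal call-site sketch, assuming the usual head of the signature (`n_threads`, `x`, `timesteps`, `context`, `c_concat`, `y`, `guidance`) as forwarded by `FluxModel::compute` above; the function and its tensor parameters are hypothetical stand-ins, not part of this commit:

    // Hypothetical call-site sketch showing where kontext_concat lands.
    #include "diffusion_model.hpp"

    void denoise_step(DiffusionModel* model,
                      int n_threads,
                      struct ggml_tensor* x,         // noisy latent
                      struct ggml_tensor* timesteps,
                      struct ggml_tensor* context,   // text conditioning
                      struct ggml_tensor* c_concat,  // Kontext reference latent (or NULL)
                      struct ggml_tensor* y,
                      struct ggml_tensor* guidance,
                      struct ggml_context* work_ctx) {
        struct ggml_tensor* out = NULL;
        model->compute(n_threads, x, timesteps, context, c_concat, y, guidance,
                       -1,    // num_video_frames
                       {},    // controls
                       0.f,   // control_strength
                       true,  // kontext_concat (new in this commit)
                       &out, work_ctx);
    }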

flux.hpp

Lines changed: 51 additions & 22 deletions
@@ -672,11 +672,11 @@ namespace Flux {
         }

         // Generate IDs for image patches and text
-        std::vector<std::vector<float>> gen_ids(int h, int w, int patch_size, int bs, int context_len) {
+        std::vector<std::vector<float>> gen_ids(int h, int w, int patch_size, int index = 0) {
             int h_len = (h + (patch_size / 2)) / patch_size;
             int w_len = (w + (patch_size / 2)) / patch_size;

-            std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(3, 0.0));
+            std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(3, (float)index));

             std::vector<float> row_ids = linspace(0, h_len - 1, h_len);
             std::vector<float> col_ids = linspace(0, w_len - 1, w_len);
@@ -688,10 +688,22 @@ namespace Flux {
                 }
             }

-            std::vector<std::vector<float>> img_ids_repeated(bs * img_ids.size(), std::vector<float>(3));
-            for (int i = 0; i < bs; ++i) {
-                for (int j = 0; j < img_ids.size(); ++j) {
-                    img_ids_repeated[i * img_ids.size() + j] = img_ids[j];
+            return img_ids;
+        }
+
+        // Generate positional embeddings
+        std::vector<float> gen_pe(std::vector<struct ggml_tensor*> imgs, struct ggml_tensor* context, int patch_size, int theta, const std::vector<int>& axes_dim) {
+            int context_len = context->ne[1];
+            int bs = imgs[0]->ne[3];
+
+            std::vector<std::vector<float>> img_ids;
+            for (int i = 0; i < imgs.size(); i++) {
+                auto x = imgs[i];
+                if (x) {
+                    int h = x->ne[1];
+                    int w = x->ne[0];
+                    std::vector<std::vector<float>> img_ids_i = gen_ids(h, w, patch_size, i);
+                    img_ids.insert(img_ids.end(), img_ids_i.begin(), img_ids_i.end());
                 }
             }

@@ -702,17 +714,10 @@ namespace Flux {
                     ids[i * (context_len + img_ids.size()) + j] = txt_ids[j];
                 }
                 for (int j = 0; j < img_ids.size(); ++j) {
-                    ids[i * (context_len + img_ids.size()) + context_len + j] = img_ids_repeated[i * img_ids.size() + j];
+                    ids[i * (context_len + img_ids.size()) + context_len + j] = img_ids[j];
                 }
             }

-            return ids;
-        }
-
-
-        // Generate positional embeddings
-        std::vector<float> gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector<int>& axes_dim) {
-            std::vector<std::vector<float>> ids = gen_ids(h, w, patch_size, bs, context_len);
             std::vector<std::vector<float>> trans_ids = transpose(ids);
             size_t pos_len = ids.size();
             int num_axes = axes_dim.size();
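
Net effect of the refactor: `gen_ids` now stamps a per-image `index` into the first id component, and `gen_pe` concatenates one id block per input image, so the main latent carries ids `(0, row, col)` while a kontext image carries `(1, row, col)`. RoPE can then separate the two images on the first axis even though their row/column coordinates overlap. A standalone sketch of that layout, using plain loops in place of `linspace` and a toy 4x4 input:

    #include <cstdio>
    #include <vector>

    // Mirrors the gen_ids logic above outside of ggml: ids are [index, row, col].
    static std::vector<std::vector<float>> gen_ids(int h, int w, int patch_size, int index = 0) {
        int h_len = (h + patch_size / 2) / patch_size;
        int w_len = (w + patch_size / 2) / patch_size;
        std::vector<std::vector<float>> ids(h_len * w_len, std::vector<float>(3, (float)index));
        for (int r = 0; r < h_len; r++)
            for (int c = 0; c < w_len; c++) {
                ids[r * w_len + c][1] = (float)r;
                ids[r * w_len + c][2] = (float)c;
            }
        return ids;
    }

    int main() {
        auto main_ids    = gen_ids(4, 4, 2, 0);  // main latent   -> index 0
        auto kontext_ids = gen_ids(4, 4, 2, 1);  // kontext image -> index 1
        // Same spatial coordinates, different first axis:
        printf("main[3]    = {%g, %g, %g}\n", main_ids[3][0], main_ids[3][1], main_ids[3][2]);
        printf("kontext[3] = {%g, %g, %g}\n", kontext_ids[3][0], kontext_ids[3][1], kontext_ids[3][2]);
    }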
@@ -932,6 +937,7 @@ namespace Flux {
                                     struct ggml_tensor* y,
                                     struct ggml_tensor* guidance,
                                     struct ggml_tensor* pe,
+                                    bool kontext_concat = false,
                                     struct ggml_tensor* arange = NULL,
                                     std::vector<int> skip_layers = std::vector<int>(),
                                     SDVersion version = VERSION_FLUX) {
@@ -956,8 +962,8 @@ namespace Flux {
             x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0);  // [N, C, H + pad_h, W + pad_w]

             // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
-            auto img = patchify(ctx, x, patch_size);  // [N, h*w, C * patch_size * patch_size]
-
+            auto img                    = patchify(ctx, x, patch_size);  // [N, h*w, C * patch_size * patch_size]
+            int64_t patchified_img_size = img->ne[1];
             if (version == VERSION_FLUX_FILL) {
                 GGML_ASSERT(c_concat != NULL);
                 ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
@@ -993,10 +999,16 @@ namespace Flux {
                 control = patchify(ctx, control, patch_size);

                 img = ggml_concat(ctx, img, control, 0);
+            } else if (kontext_concat && c_concat != NULL) {
+                ggml_tensor* kontext = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
+                kontext = patchify(ctx, kontext, patch_size);
+                img = ggml_concat(ctx, img, kontext, 1);
             }

             auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, arange, skip_layers);  // [N, h*w, C * patch_size * patch_size]

+            out = ggml_cont(ctx, ggml_view_2d(ctx, out, out->ne[0], patchified_img_size, out->nb[1], 0));
+
             // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
             out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size);  // [N, C, H + pad_h, W + pad_w]
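
Unlike the control path, which concatenates along the channel axis (dim 0), the kontext branch appends the reference image along the token axis (dim 1): `forward_orig` attends over main and reference tokens jointly and returns a sequence covering both, and the `ggml_view_2d` crop keeps only the first `patchified_img_size` tokens, since only the main latent is actually denoised. A quick token-count sketch, assuming a 128x128 latent and a same-sized reference (shapes are illustrative, not from the commit):

    #include <cstdint>
    #include <cstdio>

    // Token bookkeeping for the kontext branch above, under assumed shapes.
    int main() {
        int patch_size = 2;
        int h = 128, w = 128;  // assumed latent size; reference is the same size
        int64_t h_len = (h + patch_size / 2) / patch_size;
        int64_t w_len = (w + patch_size / 2) / patch_size;
        int64_t patchified_img_size = h_len * w_len;  // main tokens: 4096
        int64_t total = patchified_img_size * 2;      // after ggml_concat on dim 1
        printf("tokens into forward_orig: %lld\n", (long long)total);               // 8192
        printf("tokens kept by the view:  %lld\n", (long long)patchified_img_size); // 4096
    }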

@@ -1010,7 +1022,7 @@ namespace Flux {
     public:
         FluxParams flux_params;
         Flux flux;
-        std::vector<float> pe_vec, range;  // for cache
+        std::vector<float> pe_vec, concat_pe_vec, range;  // for cache
         SDVersion version;

         FluxRunner(ggml_backend_t backend,
@@ -1085,6 +1097,7 @@ namespace Flux {
                                          struct ggml_tensor* c_concat,
                                          struct ggml_tensor* y,
                                          struct ggml_tensor* guidance,
+                                         bool kontext_concat = false,
                                          std::vector<int> skip_layers = std::vector<int>()) {
             GGML_ASSERT(x->ne[3] == 1);
             struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
@@ -1113,7 +1126,6 @@ namespace Flux {
                 guidance = ggml_set_f32(guidance, 0);
             }

-
             const char* SD_CHROMA_USE_DIT_MASK = getenv("SD_CHROMA_USE_DIT_MASK");
             if (SD_CHROMA_USE_DIT_MASK != nullptr) {
                 std::string sd_chroma_use_DiT_mask_str = SD_CHROMA_USE_DIT_MASK;
@@ -1137,7 +1149,12 @@ namespace Flux {
                 guidance = to_backend(guidance);
             }

-            pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], flux_params.theta, flux_params.axes_dim);
+            std::vector<struct ggml_tensor*> imgs{x};
+            if (kontext_concat && c_concat != NULL) {
+                imgs.push_back(c_concat);
+            }
+
+            pe_vec = flux.gen_pe(imgs, context, 2, flux_params.theta, flux_params.axes_dim);
             int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
             // LOG_DEBUG("pos_len %d", pos_len);
             auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
@@ -1146,6 +1163,17 @@ namespace Flux {
             // pe->data = NULL;
             set_backend_tensor_data(pe, pe_vec.data());

+            // if (kontext_concat && c_concat != NULL) {
+            //     // TODO: offsets
+            //     concat_pe_vec = flux.gen_pe(x, context, 2, flux_params.theta, flux_params.axes_dim);
+            //     int pos_len = concat_pe_vec.size() / flux_params.axes_dim_sum / 2;
+
+            //     auto concat_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
+
+            //     set_backend_tensor_data(concat_pe, concat_pe_vec.data());
+            //     pe = ggml_concat(compute_ctx, pe, concat_pe, 3);
+            // }
+
             struct ggml_tensor* out = flux.forward(compute_ctx,
                                                    x,
                                                    timesteps,
@@ -1154,6 +1182,7 @@ namespace Flux {
                                                    y,
                                                    guidance,
                                                    pe,
+                                                   kontext_concat,
                                                    precompute_arange,
                                                    skip_layers,
                                                    version);
@@ -1170,6 +1199,7 @@ namespace Flux {
                      struct ggml_tensor* c_concat,
                      struct ggml_tensor* y,
                      struct ggml_tensor* guidance,
+                     bool kontext_concat = false,
                      struct ggml_tensor** output = NULL,
                      struct ggml_context* output_ctx = NULL,
                      std::vector<int> skip_layers = std::vector<int>()) {
@@ -1179,7 +1209,7 @@ namespace Flux {
             // y: [N, adm_in_channels] or [1, adm_in_channels]
             // guidance: [N, ]
             auto get_graph = [&]() -> struct ggml_cgraph* {
-                return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
+                return build_graph(x, timesteps, context, c_concat, y, guidance, kontext_concat, skip_layers);
             };

             return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -1219,10 +1249,9 @@ namespace Flux {
             struct ggml_tensor* out = NULL;

             int t0 = ggml_time_ms();
-            compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx);
+            compute(8, x, timesteps, context, NULL, y, guidance, false, &out, work_ctx);
             int t1 = ggml_time_ms();

-            print_ggml_tensor(out);
             LOG_DEBUG("flux test done in %dms", t1 - t0);
         }
     }
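
For comparison with the test call above, a kontext-enabled invocation would route the reference latent through `c_concat` and set the new flag; `kontext_latent` here is a hypothetical VAE-encoded reference image, not a variable from this commit:

    // Hypothetical kontext-enabled variant of the test call above:
    compute(8, x, timesteps, context, kontext_latent, y, guidance, true, &out, work_ctx);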
