
Commit 4fe0e82

slg support for flux (experimental)
1 parent 371e18b commit 4fe0e82

File tree: diffusion_model.hpp, flux.hpp

2 files changed, +21 -8 lines

diffusion_model.hpp

Lines changed: 2 additions & 1 deletion
@@ -74,6 +74,7 @@ struct UNetModel : public DiffusionModel {
                          struct ggml_tensor** output = NULL,
                          struct ggml_context* output_ctx = NULL,
                          std::vector<int> skip_layers = std::vector<int>()) {
+        (void)skip_layers;  // SLG doesn't work with UNet models
         return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx);
     }
 };
@@ -174,7 +175,7 @@ struct FluxModel : public DiffusionModel {
                          struct ggml_tensor** output = NULL,
                          struct ggml_context* output_ctx = NULL,
                          std::vector<int> skip_layers = std::vector<int>()) {
-        return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx);
+        return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers);
     }
 };
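
Note on the dispatch: only the FluxModel override forwards skip_layers to its runner; the UNetModel override discards it, since SLG is not supported for UNet models. Because the new parameter defaults to an empty vector, existing callers of compute() keep working unchanged. A minimal, self-contained sketch of that non-breaking default (compute_stub is a hypothetical stand-in, not a function from this repository):

#include <vector>

// Hypothetical stand-in for the widened compute() signature: the defaulted
// std::vector<int> argument lets old call sites omit skip_layers entirely.
static int compute_stub(int n_threads,
                        std::vector<int> skip_layers = std::vector<int>()) {
    return n_threads + (int)skip_layers.size();
}

int main() {
    compute_stub(4);             // existing callers: nothing skipped
    compute_stub(4, {7, 8, 9});  // SLG callers: illustrative block indices
    return 0;
}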

flux.hpp

Lines changed: 19 additions & 7 deletions
@@ -712,7 +712,8 @@ namespace Flux {
                              struct ggml_tensor* timesteps,
                              struct ggml_tensor* y,
                              struct ggml_tensor* guidance,
-                             struct ggml_tensor* pe) {
+                             struct ggml_tensor* pe,
+                             std::vector<int> skip_layers = std::vector<int>()) {
         auto img_in    = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
         auto time_in   = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
         auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
@@ -734,6 +735,10 @@ namespace Flux {
         txt = txt_in->forward(ctx, txt);

         for (int i = 0; i < params.depth; i++) {
+            if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
+                continue;
+            }
+
             auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]);

             auto img_txt = block->forward(ctx, img, txt, vec, pe);
@@ -743,6 +748,9 @@ namespace Flux {

         auto txt_img = ggml_concat(ctx, txt, img, 1);  // [N, n_txt_token + n_img_token, hidden_size]
         for (int i = 0; i < params.depth_single_blocks; i++) {
+            if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + params.depth) != skip_layers.end()) {
+                continue;
+            }
             auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);

             txt_img = block->forward(ctx, txt_img, vec, pe);
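
The two loops above define the layer-index convention for skip_layers: double blocks are matched against their own index i, single blocks against i + params.depth, so one flat list of indices addresses both stacks. A self-contained sketch of that mapping (the depth values below are assumptions for illustration; the real ones come from params, not from this diff):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int depth               = 19;             // assumed number of double blocks
    const int depth_single_blocks = 38;             // assumed number of single blocks
    std::vector<int> skip_layers  = {7, 8, 9, 25};  // illustrative choice

    // Same membership test the patch adds in forward_orig().
    auto skipped = [&](int idx) {
        return std::find(skip_layers.begin(), skip_layers.end(), idx) != skip_layers.end();
    };

    for (int i = 0; i < depth; i++)
        if (skipped(i))
            std::printf("double_blocks.%d skipped\n", i);

    for (int i = 0; i < depth_single_blocks; i++)
        if (skipped(i + depth))
            std::printf("single_blocks.%d skipped (flat index %d)\n", i, i + depth);

    return 0;
}

With these assumed depths, indices 7 to 9 skip double blocks and index 25 skips single block 6.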
@@ -770,7 +778,8 @@ namespace Flux {
                              struct ggml_tensor* context,
                              struct ggml_tensor* y,
                              struct ggml_tensor* guidance,
-                             struct ggml_tensor* pe) {
+                             struct ggml_tensor* pe,
+                             std::vector<int> skip_layers = std::vector<int>()) {
         // Forward pass of DiT.
         // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
         // timestep: (N,) tensor of diffusion timesteps
@@ -792,7 +801,7 @@ namespace Flux {
         // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
         auto img = patchify(ctx, x, patch_size);  // [N, h*w, C * patch_size * patch_size]

-        auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe);  // [N, h*w, C * patch_size * patch_size]
+        auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers);  // [N, h*w, C * patch_size * patch_size]

         // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
         out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size);  // [N, C, H + pad_h, W + pad_w]
@@ -833,7 +842,8 @@ namespace Flux {
                              struct ggml_tensor* timesteps,
                              struct ggml_tensor* context,
                              struct ggml_tensor* y,
-                             struct ggml_tensor* guidance) {
+                             struct ggml_tensor* guidance,
+                             std::vector<int> skip_layers = std::vector<int>()) {
         GGML_ASSERT(x->ne[3] == 1);
         struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);

@@ -860,7 +870,8 @@ namespace Flux {
                              context,
                              y,
                              guidance,
-                             pe);
+                             pe,
+                             skip_layers);

         ggml_build_forward_expand(gf, out);

@@ -874,14 +885,15 @@ namespace Flux {
                      struct ggml_tensor* y,
                      struct ggml_tensor* guidance,
                      struct ggml_tensor** output = NULL,
-                     struct ggml_context* output_ctx = NULL) {
+                     struct ggml_context* output_ctx = NULL,
+                     std::vector<int> skip_layers = std::vector<int>()) {
         // x: [N, in_channels, h, w]
         // timesteps: [N, ]
         // context: [N, max_position, hidden_size]
         // y: [N, adm_in_channels] or [1, adm_in_channels]
         // guidance: [N, ]
         auto get_graph = [&]() -> struct ggml_cgraph* {
-            return build_graph(x, timesteps, context, y, guidance);
+            return build_graph(x, timesteps, context, y, guidance, skip_layers);
         };

         GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
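
For context, a caller requesting the extra skip-layer pass would invoke the runner's compute() the same way the FluxModel wrapper in diffusion_model.hpp does, just with the new trailing argument. A hedged fragment, not a complete program: flux, the tensors, work_ctx and the index choice are placeholders for whatever the sampler already holds, and combining the skipped-layer output into the final guidance happens outside this commit.

// Same arguments as a normal Flux forward pass, plus the blocks to drop;
// omitting skip_layers (or passing an empty vector) skips nothing.
std::vector<int> skip_layers = {7, 8, 9};  // illustrative block indices
struct ggml_tensor* out_skip = NULL;

flux.compute(n_threads,
             x, timesteps, context, y, guidance,
             &out_skip, work_ctx,
             skip_layers);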
