Skip to content

Commit 2210257

Browse files
committed
Flux Lite (Freepik) support
1 parent a10dd7c commit 2210257

File tree

6 files changed

+68
-21
lines changed

6 files changed

+68
-21
lines changed

examples/cli/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,7 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
853853

854854
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
855855
latent_rgb_proj = sd3_latent_rgb_proj;
856-
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
856+
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL|| version == VERSION_FLUX_LITE) {
857857
latent_rgb_proj = flux_latent_rgb_proj;
858858
} else {
859859
// unknown model

flux.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,9 @@ namespace Flux {
826826
if (version == VERSION_FLUX_SCHNELL) {
827827
flux_params.guidance_embed = false;
828828
}
829+
if (version == VERSION_FLUX_LITE){
830+
flux_params.depth = 8;
831+
}
829832
flux = Flux(flux_params);
830833
flux.init(params_ctx, tensor_types, prefix);
831834
}

model.cpp

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1393,15 +1393,20 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
13931393

13941394
SDVersion ModelLoader::get_sd_version() {
13951395
TensorStorage token_embedding_weight;
1396-
bool is_flux = false;
1397-
bool is_sd3 = false;
1396+
bool is_flux = false;
1397+
bool is_schnell = true;
1398+
bool is_lite = true;
1399+
bool is_sd3 = false;
13981400
for (auto& tensor_storage : tensor_storages) {
13991401
if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
1400-
return VERSION_FLUX_DEV;
1402+
is_schnell = false;
14011403
}
14021404
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
14031405
is_flux = true;
14041406
}
1407+
if (tensor_storage.name.find("model.diffusion_model.double_blocks.8") != std::string::npos) {
1408+
is_lite = false;
1409+
}
14051410
if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) {
14061411
return VERSION_SD3_5_2B;
14071412
}
@@ -1432,7 +1437,14 @@ SDVersion ModelLoader::get_sd_version() {
14321437
}
14331438
}
14341439
if (is_flux) {
1435-
return VERSION_FLUX_SCHNELL;
1440+
if (is_schnell) {
1441+
GGML_ASSERT(!is_lite);
1442+
return VERSION_FLUX_SCHNELL;
1443+
} else if (is_lite) {
1444+
return VERSION_FLUX_LITE;
1445+
} else {
1446+
return VERSION_FLUX_DEV;
1447+
}
14361448
}
14371449
if (is_sd3) {
14381450
return VERSION_SD3_2B;
@@ -1856,7 +1868,21 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
18561868
const std::string& name = tensor_storage.name;
18571869

18581870
ggml_type tensor_type = tensor_storage.type;
1859-
tensor_set_type(tensor_type, tensor_storage, type, fallback_type);
1871+
auto _type = type;
1872+
// attemmpt to improve q2_k quant by using higher quants for final blocks
1873+
if (type == GGML_TYPE_Q2_K) {
1874+
if (name.find("single_blocks.37") != std::string::npos ||
1875+
name.find("double_blocks.0") != std::string::npos) {
1876+
_type = GGML_TYPE_Q4_K;
1877+
} else if (name.find("single_blocks.36") != std::string::npos ||
1878+
name.find("single_blocks.35") != std::string::npos ||
1879+
name.find("single_blocks.0") != std::string::npos ||
1880+
name.find("double_blocks.18") != std::string::npos ||
1881+
name.find("double_blocks.1") != std::string::npos) {
1882+
_type = GGML_TYPE_Q3_K;
1883+
}
1884+
}
1885+
tensor_set_type(tensor_type, tensor_storage, _type, fallback_type);
18601886

18611887
ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
18621888
if (tensor == NULL) {
@@ -1890,7 +1916,8 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
18901916
return success;
18911917
}
18921918

1893-
int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type, ggml_type fallback_type /*= GGML_TYPE_COUNT*/) {
1919+
int64_t
1920+
ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type, ggml_type fallback_type /*= GGML_TYPE_COUNT*/) {
18941921
size_t alignment = 128;
18951922
if (backend != NULL) {
18961923
alignment = ggml_backend_get_alignment(backend);
@@ -1905,7 +1932,22 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type,
19051932
}
19061933

19071934
for (auto& tensor_storage : processed_tensor_storages) {
1908-
tensor_set_type(tensor_storage.type, tensor_storage, type, fallback_type);
1935+
auto _type = type;
1936+
auto name = tensor_storage.name;
1937+
// attemmpt to improve q2_k quant by using higher quants for final blocks
1938+
if (type == GGML_TYPE_Q2_K) {
1939+
if (name.find("single_blocks.37") != std::string::npos ||
1940+
name.find("double_blocks.0") != std::string::npos) {
1941+
_type = GGML_TYPE_Q4_K;
1942+
} else if (name.find("single_blocks.36") != std::string::npos ||
1943+
name.find("single_blocks.35") != std::string::npos ||
1944+
name.find("single_blocks.0") != std::string::npos ||
1945+
name.find("double_blocks.18") != std::string::npos ||
1946+
name.find("double_blocks.1") != std::string::npos) {
1947+
_type = GGML_TYPE_Q3_K;
1948+
}
1949+
}
1950+
tensor_set_type(tensor_storage.type, tensor_storage, _type, fallback_type);
19091951
mem_size += tensor_storage.nbytes() + alignment;
19101952
}
19111953

model.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ enum SDVersion {
2727
VERSION_FLUX_SCHNELL,
2828
VERSION_SD3_5_8B,
2929
VERSION_SD3_5_2B,
30+
VERSION_FLUX_LITE,
3031
VERSION_COUNT,
3132
};
3233

stable-diffusion.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ const char* model_version_to_str[] = {
3333
"Flux Dev",
3434
"Flux Schnell",
3535
"SD3.5 8B",
36-
"SD3.5 2B"};
36+
"SD3.5 2B",
37+
"Flux Lite 8B"};
3738

3839
const char* sampling_methods_str[] = {
3940
"Euler A",
@@ -291,7 +292,7 @@ class StableDiffusionGGML {
291292
}
292293
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
293294
scale_factor = 1.5305f;
294-
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
295+
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
295296
scale_factor = 0.3611;
296297
// TODO: shift_factor
297298
}
@@ -312,7 +313,7 @@ class StableDiffusionGGML {
312313
} else {
313314
clip_backend = backend;
314315
bool use_t5xxl = false;
315-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
316+
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
316317
use_t5xxl = true;
317318
}
318319
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
@@ -326,7 +327,7 @@ class StableDiffusionGGML {
326327
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
327328
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
328329
diffusion_model = std::make_shared<MMDiTModel>(backend, model_loader.tensor_storages_types, version);
329-
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
330+
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
330331
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
331332
diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, version);
332333
} else {
@@ -525,7 +526,7 @@ class StableDiffusionGGML {
525526
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
526527
LOG_INFO("running in FLOW mode");
527528
denoiser = std::make_shared<DiscreteFlowDenoiser>();
528-
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
529+
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
529530
LOG_INFO("running in Flux FLOW mode");
530531
float shift = 1.15f;
531532
if (version == VERSION_FLUX_SCHNELL) {
@@ -811,7 +812,7 @@ class StableDiffusionGGML {
811812
out_uncond = ggml_dup_tensor(tmp_ctx, x);
812813
}
813814
if (has_skiplayer) {
814-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
815+
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
815816
out_skip = ggml_dup_tensor(tmp_ctx, x);
816817
} else {
817818
has_skiplayer = false;
@@ -1008,7 +1009,7 @@ class StableDiffusionGGML {
10081009
} else {
10091010
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
10101011
C = 32;
1011-
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
1012+
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
10121013
C = 32;
10131014
}
10141015
}
@@ -1346,7 +1347,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13461347
int C = 4;
13471348
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
13481349
C = 16;
1349-
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
1350+
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
13501351
C = 16;
13511352
}
13521353
int W = width / 8;
@@ -1471,7 +1472,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
14711472
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
14721473
params.mem_size *= 3;
14731474
}
1474-
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
1475+
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
14751476
params.mem_size *= 4;
14761477
}
14771478
if (sd_ctx->sd->stacked_id) {
@@ -1496,15 +1497,15 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
14961497
int C = 4;
14971498
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
14981499
C = 16;
1499-
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
1500+
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
15001501
C = 16;
15011502
}
15021503
int W = width / 8;
15031504
int H = height / 8;
15041505
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
15051506
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
15061507
ggml_set_f32(init_latent, 0.0609f);
1507-
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
1508+
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
15081509
ggml_set_f32(init_latent, 0.1159f);
15091510
} else {
15101511
ggml_set_f32(init_latent, 0.f);
@@ -1575,7 +1576,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
15751576
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
15761577
params.mem_size *= 2;
15771578
}
1578-
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
1579+
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
15791580
params.mem_size *= 3;
15801581
}
15811582
if (sd_ctx->sd->stacked_id) {

vae.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ class AutoencodingEngine : public GGMLBlock {
458458
bool use_video_decoder = false,
459459
SDVersion version = VERSION_SD1)
460460
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
461-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
461+
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
462462
dd_config.z_channels = 16;
463463
use_quant = false;
464464
}

0 commit comments

Comments
 (0)