Commit 97495f4

lora: Lycoris LoKr support
1 parent: 97112f3

File tree: 2 files changed (+138, -58 lines)


ggml_extend.hpp (15 additions, 2 deletions)
@@ -104,6 +104,19 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_merge_lora(ggml_context* ctx, struct
     return updown;
 }
 
+// Kronecker product
+// [ne03,ne02,ne01,ne00] x [ne13,ne12,ne11,ne10] => [ne03*ne13,ne02*ne12,ne01*ne11,ne00*ne10]
+__STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b) {
+    return ggml_mul(ctx,
+                    ggml_upscale_ext(ctx,
+                                     a,
+                                     a->ne[0] * b->ne[0],
+                                     a->ne[1] * b->ne[1],
+                                     a->ne[2] * b->ne[2],
+                                     a->ne[3] * b->ne[3]),
+                    b);
+}
+
 __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void* user_data) {
     (void)level;
     (void)user_data;
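
The new `ggml_kronecker()` helper builds the Kronecker product by nearest-neighbour upscaling `a` to the combined shape and then multiplying elementwise with `b`, which ggml repeats across the larger tensor. The sketch below is not part of the commit and does not use ggml at all; it only spells out that construction for the 1-D case so the index arithmetic is easy to follow:

```cpp
// Standalone sketch (not from the commit, no ggml): the idea behind
// ggml_kronecker() for 1-D inputs. Nearest-neighbour upscaling of a to
// length |a|*|b| repeats each a[i] |b| times; multiplying elementwise by b
// (tiled/broadcast over the output) yields the Kronecker product a (x) b.
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> a = {1, 2, 3};
    std::vector<float> b = {10, 100};
    const int n = (int)a.size(), m = (int)b.size();

    std::vector<float> kron(n * m);
    for (int i = 0; i < n * m; i++) {
        float a_up  = a[i / m];   // "upscaled" a: each element repeated m times
        float b_rep = b[i % m];   // b repeated (broadcast) across the output
        kron[i] = a_up * b_rep;   // 10 100 20 200 30 300
    }

    for (float v : kron) printf("%g ", v);
    printf("\n");
    return 0;
}
```

The same per-dimension logic is what the shape comment in the helper describes: every output extent is the product of the corresponding extents of `a` and `b`.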
@@ -1110,8 +1123,8 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
 }
 
 /* SDXL with LoRA requires more space */
-#define MAX_PARAMS_TENSOR_NUM 16384
-#define MAX_GRAPH_SIZE 16384
+#define MAX_PARAMS_TENSOR_NUM 32768
+#define MAX_GRAPH_SIZE 32768
 
 struct GGMLRunner {
 protected:

lora.hpp (123 additions, 56 deletions)
@@ -260,7 +260,8 @@ struct LoraModel : public GGMLRunner {
             }
             struct ggml_tensor* updown = NULL;
             float scale_value = 1.0f;
-            if (lora_tensors.find(lora_pre[type] + key + ".hada_w1_a") != lora_tensors.end()) {
+            std::string fk = lora_pre[type] + key;
+            if (lora_tensors.find(fk + ".hada_w1_a") != lora_tensors.end()) {
                 // loHa mode
 
                 std::string alpha_name = "";
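
The `fk` variable introduced in this hunk simply caches the name prefix shared by every tensor lookup that follows. A small illustration with hypothetical values; the real prefix comes from `lora_pre[type]` and the key from the target weight name, neither of which appears in this diff:

```cpp
// Hypothetical values, for illustration only: the actual prefix and key are
// produced elsewhere in lora.hpp and are not part of this hunk.
#include <iostream>
#include <string>

int main() {
    std::string lora_prefix = "lora_unet_";                       // assumed example prefix
    std::string key         = "down_blocks_0_attentions_0_proj";  // assumed example key
    std::string fk          = lora_prefix + key;                  // the cached "full key"

    // Per-variant tensor names are then derived from fk:
    std::cout << fk + ".hada_w1_a" << "\n";  // LoHa
    std::cout << fk + ".lokr_w1"   << "\n";  // LoKr
    std::cout << fk + ".alpha"     << "\n";  // shared alpha scale
    return 0;
}
```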
@@ -282,49 +283,49 @@
                 std::string hada_2_up_name = "";
 
                 // TODO: split qkv convention for LoHas (is it ever used?)
-                if(is_qkv_split || is_qkvm_split){
+                if (is_qkv_split || is_qkvm_split) {
                     LOG_ERROR("Split qkv isn't supported for LoHa models.");
                     break;
                 }
 
-                hada_1_down_name = lora_pre[type] + key + ".hada_w1_b";
-                hada_1_up_name = lora_pre[type] + key + ".hada_w1_a";
-                hada_1_mid_name = lora_pre[type] + key + ".hada_t1";
+                hada_1_down_name = fk + ".hada_w1_b";
+                hada_1_up_name = fk + ".hada_w1_a";
+                hada_1_mid_name = fk + ".hada_t1";
                 if (lora_tensors.find(hada_1_down_name) != lora_tensors.end()) {
-                    hada_1_down = lora_tensors[hada_1_down_name];
+                    hada_1_down = to_f32(compute_ctx, lora_tensors[hada_1_down_name]);
                 }
                 if (lora_tensors.find(hada_1_up_name) != lora_tensors.end()) {
-                    hada_1_up = lora_tensors[hada_1_up_name];
+                    hada_1_up = to_f32(compute_ctx, lora_tensors[hada_1_up_name]);
                 }
                 if (lora_tensors.find(hada_1_mid_name) != lora_tensors.end()) {
-                    hada_1_mid = lora_tensors[hada_1_mid_name];
+                    hada_1_mid = to_f32(compute_ctx, lora_tensors[hada_1_mid_name]);
                     applied_lora_tensors.insert(hada_1_mid_name);
                     hada_1_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_1_up));
                 }
 
-                hada_2_down_name = lora_pre[type] + key + ".hada_w2_b";
-                hada_2_up_name = lora_pre[type] + key + ".hada_w2_a";
-                hada_2_mid_name = lora_pre[type] + key + ".hada_t2";
+                hada_2_down_name = fk + ".hada_w2_b";
+                hada_2_up_name = fk + ".hada_w2_a";
+                hada_2_mid_name = fk + ".hada_t2";
                 if (lora_tensors.find(hada_2_down_name) != lora_tensors.end()) {
-                    hada_2_down = lora_tensors[hada_2_down_name];
+                    hada_2_down = to_f32(compute_ctx, lora_tensors[hada_2_down_name]);
                 }
                 if (lora_tensors.find(hada_2_up_name) != lora_tensors.end()) {
-                    hada_2_up = lora_tensors[hada_2_up_name];
+                    hada_2_up = to_f32(compute_ctx, lora_tensors[hada_2_up_name]);
                 }
                 if (lora_tensors.find(hada_2_mid_name) != lora_tensors.end()) {
-                    hada_2_mid = lora_tensors[hada_2_mid_name];
+                    hada_2_mid = to_f32(compute_ctx, lora_tensors[hada_2_mid_name]);
                     applied_lora_tensors.insert(hada_2_mid_name);
                     hada_2_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_2_up));
                 }
 
-                alpha_name = lora_pre[type] + key + ".alpha";
+                alpha_name = fk + ".alpha";
 
                 applied_lora_tensors.insert(hada_1_down_name);
                 applied_lora_tensors.insert(hada_1_up_name);
                 applied_lora_tensors.insert(hada_2_down_name);
                 applied_lora_tensors.insert(hada_2_up_name);
-                applied_lora_tensors.insert(alpha_name);
 
+                applied_lora_tensors.insert(alpha_name);
                 if (hada_1_up == NULL || hada_1_down == NULL || hada_2_up == NULL || hada_2_down == NULL) {
                     continue;
                 }
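
For context on the LoHa branch touched above: LyCORIS LoHa composes its weight delta as the elementwise (Hadamard) product of two low-rank products, roughly delta_W = (alpha / rank) * (w1_a * w1_b) .* (w2_a * w2_b). The sketch below is a conceptual, dense-matrix illustration of that formula, not the ggml graph the code builds:

```cpp
// Conceptual sketch (not the ggml code above): LoHa builds its delta as the
// Hadamard product of two low-rank products, scaled by alpha / rank.
// Dense row-major matrices, kept tiny for clarity.
#include <cstdio>
#include <vector>

// C = A (m x k) * B (k x n), row-major.
static std::vector<float> matmul(const std::vector<float>& A, const std::vector<float>& B,
                                 int m, int k, int n) {
    std::vector<float> C(static_cast<size_t>(m) * n, 0.0f);
    for (int i = 0; i < m; i++)
        for (int p = 0; p < k; p++)
            for (int j = 0; j < n; j++)
                C[i * n + j] += A[i * k + p] * B[p * n + j];
    return C;
}

int main() {
    const int out = 2, in = 2, rank = 1;
    // Factor pair 1 (up: out x rank, down: rank x in) and factor pair 2.
    std::vector<float> w1_a = {1, 2}, w1_b = {1, 1};
    std::vector<float> w2_a = {1, 1}, w2_b = {2, 3};
    float alpha = 1.0f, scale = alpha / rank;

    std::vector<float> p1 = matmul(w1_a, w1_b, out, rank, in);
    std::vector<float> p2 = matmul(w2_a, w2_b, out, rank, in);

    for (int i = 0; i < out * in; i++) {
        float delta = scale * p1[i] * p2[i];  // Hadamard (elementwise) product entry
        printf("%6.2f%s", delta, (i % in == in - 1) ? "\n" : " ");
    }
    return 0;
}
```

The optional hada_t1 / hada_t2 tensors handled above add a Tucker-style middle factor to each product; the sketch omits them.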
@@ -340,9 +341,75 @@
                     float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]);
                     scale_value = alpha / dim;
                 }
-            } else if (lora_tensors.find(lora_pre[type] + key + ".lokr_w1") != lora_tensors.end()) {
-                LOG_WARN("LoKr is not supported yet");
-                break;
+            } else if (lora_tensors.find(fk + ".lokr_w1") != lora_tensors.end() || lora_tensors.find(fk + ".lokr_w1_a") != lora_tensors.end()) {
+                // LOG_WARN("LoKr is not supported yet");
+                // break;
+                std::string alpha_name = fk + ".alpha";
+                ;
+
+                ggml_tensor* lokr_w1 = NULL;
+                ggml_tensor* lokr_w2 = NULL;
+
+                std::string lokr_w1_name = "";
+                std::string lokr_w2_name = "";
+
+                // TODO: split qkv convention for LoKrs (is it ever used?)
+                if (is_qkv_split || is_qkvm_split) {
+                    LOG_ERROR("Split qkv isn't supported for LoKr models.");
+                    break;
+                }
+
+                lokr_w1_name = fk + ".lokr_w1";
+                lokr_w2_name = fk + ".lokr_w2";
+
+                if (lora_tensors.find(lokr_w1_name) != lora_tensors.end()) {
+                    lokr_w1 = to_f32(compute_ctx, lora_tensors[lokr_w1_name]);
+                    applied_lora_tensors.insert(lokr_w1_name);
+                } else {
+                    ggml_tensor* down = NULL;
+                    ggml_tensor* up = NULL;
+                    std::string down_name = lokr_w1_name + "_b";
+                    std::string up_name = lokr_w1_name + "_a";
+                    if (lora_tensors.find(down_name) != lora_tensors.end()) {
+                        down = to_f32(compute_ctx, lora_tensors[down_name]);
+                        applied_lora_tensors.insert(down_name);
+
+                        // scale != 1 only when using Low rank form (?)
+                        int64_t dim = down->ne[ggml_n_dims(down) - 1];
+                        if (lora_tensors.find(alpha_name) != lora_tensors.end()) {
+                            float alpha = ggml_backend_tensor_get_f32(to_f32(compute_ctx, lora_tensors[alpha_name]));
+                            scale_value = alpha / dim;
+                        }
+                    }
+                    if (lora_tensors.find(up_name) != lora_tensors.end()) {
+                        up = to_f32(compute_ctx, lora_tensors[up_name]);
+                        applied_lora_tensors.insert(up_name);
+                    }
+                    lokr_w1 = ggml_merge_lora(compute_ctx, down, up);
+                }
+                if (lora_tensors.find(lokr_w2_name) != lora_tensors.end()) {
+                    lokr_w2 = to_f32(compute_ctx, lora_tensors[lokr_w2_name]);
+                    applied_lora_tensors.insert(lokr_w2_name);
+                } else {
+                    ggml_tensor* down = NULL;
+                    ggml_tensor* up = NULL;
+                    std::string down_name = lokr_w2_name + "_b";
+                    std::string up_name = lokr_w2_name + "_a";
+                    if (lora_tensors.find(down_name) != lora_tensors.end()) {
+                        down = to_f32(compute_ctx, lora_tensors[down_name]);
+                        applied_lora_tensors.insert(down_name);
+                    }
+                    if (lora_tensors.find(up_name) != lora_tensors.end()) {
+                        up = to_f32(compute_ctx, lora_tensors[up_name]);
+                        applied_lora_tensors.insert(up_name);
+                    }
+                    lokr_w2 = ggml_merge_lora(compute_ctx, down, up);
+                }
+
+                updown = ggml_kronecker(compute_ctx, lokr_w1, lokr_w2);
+
+                // TODO: double check alphas
+                applied_lora_tensors.insert(alpha_name);
             } else {
                 // LoRA mode
                 ggml_tensor* lora_mid = NULL;  // tau for tucker decomposition
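
This LoKr branch is the heart of the commit: the delta is the Kronecker product of two factors, each of which is either stored directly (`lokr_w1`, `lokr_w2`) or rebuilt from a low-rank pair (`*_a` / `*_b`) via `ggml_merge_lora()`, with `scale_value = alpha / dim` applied only in the low-rank case. Below is a standalone, dense-matrix sketch of that composition with hypothetical shapes and values, not the ggml implementation:

```cpp
// Conceptual sketch (not the ggml code above): LoKr expresses the delta of an
// (out x in) layer as delta_W = scale * kron(w1, w2), with w1 of shape
// (out1 x in1), w2 of shape (out2 x in2), out1*out2 = out, in1*in2 = in.
// When the checkpoint stores "lokr_w1_a"/"lokr_w1_b" instead of "lokr_w1",
// w1 is first rebuilt from the low-rank pair, mirroring the fallback above.
#include <cstdio>
#include <vector>

static std::vector<float> matmul(const std::vector<float>& A, const std::vector<float>& B,
                                 int m, int k, int n) {  // (m x k) * (k x n), row-major
    std::vector<float> C(static_cast<size_t>(m) * n, 0.0f);
    for (int i = 0; i < m; i++)
        for (int p = 0; p < k; p++)
            for (int j = 0; j < n; j++)
                C[i * n + j] += A[i * k + p] * B[p * n + j];
    return C;
}

static std::vector<float> kron(const std::vector<float>& A, int m, int n,
                               const std::vector<float>& B, int p, int q) {
    std::vector<float> C(static_cast<size_t>(m) * p * n * q);
    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++)
            for (int k = 0; k < p; k++)
                for (int l = 0; l < q; l++)
                    C[(i * p + k) * (n * q) + (j * q + l)] = A[i * n + j] * B[k * q + l];
    return C;
}

int main() {
    // Hypothetical shapes: w1 is 2 x 2 (rebuilt from a rank-1 pair), w2 is 3 x 2,
    // so delta_W is 6 x 4.
    const int rank = 1;
    std::vector<float> w1_a = {1, 2};  // 2 x rank
    std::vector<float> w1_b = {3, 4};  // rank x 2
    std::vector<float> w1   = matmul(w1_a, w1_b, 2, rank, 2);

    std::vector<float> w2 = {1, 2,
                             3, 4,
                             5, 6};    // 3 x 2
    float alpha = 1.0f, scale = alpha / rank;  // only set from alpha in the low-rank case

    std::vector<float> dW = kron(w1, 2, 2, w2, 3, 2);
    for (int r = 0; r < 6; r++) {
        for (int c = 0; c < 4; c++) printf("%6.1f", scale * dW[r * 4 + c]);
        printf("\n");
    }
    return 0;
}
```

Because kron(w1, w2) has shape (out1*out2) x (in1*in2), small factors can cover a large weight matrix, which is the point of the LoKr parameterization.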
@@ -358,29 +425,29 @@
 
                 if (is_qkv_split) {
                     std::string suffix = "";
-                    auto split_q_d_name = lora_pre[type] + key + "q" + suffix + lora_downs[type] + ".weight";
+                    auto split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight";
 
                     if (lora_tensors.find(split_q_d_name) == lora_tensors.end()) {
                         suffix = "_proj";
-                        split_q_d_name = lora_pre[type] + key + "q" + suffix + lora_downs[type] + ".weight";
+                        split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight";
                     }
                     if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
                         // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
                         // find qkv and mlp up parts in LoRA model
-                        auto split_k_d_name = lora_pre[type] + key + "k" + suffix + lora_downs[type] + ".weight";
-                        auto split_v_d_name = lora_pre[type] + key + "v" + suffix + lora_downs[type] + ".weight";
+                        auto split_k_d_name = fk + "k" + suffix + lora_downs[type] + ".weight";
+                        auto split_v_d_name = fk + "v" + suffix + lora_downs[type] + ".weight";
 
-                        auto split_q_u_name = lora_pre[type] + key + "q" + suffix + lora_ups[type] + ".weight";
-                        auto split_k_u_name = lora_pre[type] + key + "k" + suffix + lora_ups[type] + ".weight";
-                        auto split_v_u_name = lora_pre[type] + key + "v" + suffix + lora_ups[type] + ".weight";
+                        auto split_q_u_name = fk + "q" + suffix + lora_ups[type] + ".weight";
+                        auto split_k_u_name = fk + "k" + suffix + lora_ups[type] + ".weight";
+                        auto split_v_u_name = fk + "v" + suffix + lora_ups[type] + ".weight";
 
-                        auto split_q_scale_name = lora_pre[type] + key + "q" + suffix + ".scale";
-                        auto split_k_scale_name = lora_pre[type] + key + "k" + suffix + ".scale";
-                        auto split_v_scale_name = lora_pre[type] + key + "v" + suffix + ".scale";
+                        auto split_q_scale_name = fk + "q" + suffix + ".scale";
+                        auto split_k_scale_name = fk + "k" + suffix + ".scale";
+                        auto split_v_scale_name = fk + "v" + suffix + ".scale";
 
-                        auto split_q_alpha_name = lora_pre[type] + key + "q" + suffix + ".alpha";
-                        auto split_k_alpha_name = lora_pre[type] + key + "k" + suffix + ".alpha";
-                        auto split_v_alpha_name = lora_pre[type] + key + "v" + suffix + ".alpha";
+                        auto split_q_alpha_name = fk + "q" + suffix + ".alpha";
+                        auto split_k_alpha_name = fk + "k" + suffix + ".alpha";
+                        auto split_v_alpha_name = fk + "v" + suffix + ".alpha";
 
                         ggml_tensor* lora_q_down = NULL;
                         ggml_tensor* lora_q_up = NULL;
@@ -494,29 +561,29 @@
                         applied_lora_tensors.insert(split_v_d_name);
                     }
                 } else if (is_qkvm_split) {
-                    auto split_q_d_name = lora_pre[type] + key + "attn.to_q" + lora_downs[type] + ".weight";
+                    auto split_q_d_name = fk + "attn.to_q" + lora_downs[type] + ".weight";
                     if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
                         // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
                         // find qkv and mlp up parts in LoRA model
-                        auto split_k_d_name = lora_pre[type] + key + "attn.to_k" + lora_downs[type] + ".weight";
-                        auto split_v_d_name = lora_pre[type] + key + "attn.to_v" + lora_downs[type] + ".weight";
+                        auto split_k_d_name = fk + "attn.to_k" + lora_downs[type] + ".weight";
+                        auto split_v_d_name = fk + "attn.to_v" + lora_downs[type] + ".weight";
 
-                        auto split_q_u_name = lora_pre[type] + key + "attn.to_q" + lora_ups[type] + ".weight";
-                        auto split_k_u_name = lora_pre[type] + key + "attn.to_k" + lora_ups[type] + ".weight";
-                        auto split_v_u_name = lora_pre[type] + key + "attn.to_v" + lora_ups[type] + ".weight";
+                        auto split_q_u_name = fk + "attn.to_q" + lora_ups[type] + ".weight";
+                        auto split_k_u_name = fk + "attn.to_k" + lora_ups[type] + ".weight";
+                        auto split_v_u_name = fk + "attn.to_v" + lora_ups[type] + ".weight";
 
-                        auto split_m_d_name = lora_pre[type] + key + "proj_mlp" + lora_downs[type] + ".weight";
-                        auto split_m_u_name = lora_pre[type] + key + "proj_mlp" + lora_ups[type] + ".weight";
+                        auto split_m_d_name = fk + "proj_mlp" + lora_downs[type] + ".weight";
+                        auto split_m_u_name = fk + "proj_mlp" + lora_ups[type] + ".weight";
 
-                        auto split_q_scale_name = lora_pre[type] + key + "attn.to_q" + ".scale";
-                        auto split_k_scale_name = lora_pre[type] + key + "attn.to_k" + ".scale";
-                        auto split_v_scale_name = lora_pre[type] + key + "attn.to_v" + ".scale";
-                        auto split_m_scale_name = lora_pre[type] + key + "proj_mlp" + ".scale";
+                        auto split_q_scale_name = fk + "attn.to_q" + ".scale";
+                        auto split_k_scale_name = fk + "attn.to_k" + ".scale";
+                        auto split_v_scale_name = fk + "attn.to_v" + ".scale";
+                        auto split_m_scale_name = fk + "proj_mlp" + ".scale";
 
-                        auto split_q_alpha_name = lora_pre[type] + key + "attn.to_q" + ".alpha";
-                        auto split_k_alpha_name = lora_pre[type] + key + "attn.to_k" + ".alpha";
-                        auto split_v_alpha_name = lora_pre[type] + key + "attn.to_v" + ".alpha";
-                        auto split_m_alpha_name = lora_pre[type] + key + "proj_mlp" + ".alpha";
+                        auto split_q_alpha_name = fk + "attn.to_q" + ".alpha";
+                        auto split_k_alpha_name = fk + "attn.to_k" + ".alpha";
+                        auto split_v_alpha_name = fk + "attn.to_v" + ".alpha";
+                        auto split_m_alpha_name = fk + "proj_mlp" + ".alpha";
 
                         ggml_tensor* lora_q_down = NULL;
                         ggml_tensor* lora_q_up = NULL;
@@ -671,23 +738,23 @@
                         applied_lora_tensors.insert(split_m_d_name);
                     }
                 } else {
-                    lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
-                    lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
-                    lora_mid_name = lora_pre[type] + key + ".lora_mid.weight";
+                    lora_up_name = fk + lora_ups[type] + ".weight";
+                    lora_down_name = fk + lora_downs[type] + ".weight";
+                    lora_mid_name = fk + ".lora_mid.weight";
 
-                    alpha_name = lora_pre[type] + key + ".alpha";
-                    scale_name = lora_pre[type] + key + ".scale";
+                    alpha_name = fk + ".alpha";
+                    scale_name = fk + ".scale";
 
                     if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
-                        lora_up = lora_tensors[lora_up_name];
+                        lora_up = to_f32(compute_ctx, lora_tensors[lora_up_name]);
                     }
 
                     if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
-                        lora_down = lora_tensors[lora_down_name];
+                        lora_down = to_f32(compute_ctx, lora_tensors[lora_down_name]);
                     }
 
                     if (lora_tensors.find(lora_mid_name) != lora_tensors.end()) {
-                        lora_mid = lora_tensors[lora_mid_name];
+                        lora_mid = to_f32(compute_ctx, lora_tensors[lora_mid_name]);
                         applied_lora_tensors.insert(lora_mid_name);
                     }
 
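The remaining hunks are mechanical: they substitute the cached `fk` prefix into the split-qkv and plain LoRA name construction and route every fetched tensor through `to_f32(compute_ctx, ...)` before use. For completeness, here is a conceptual sketch of the classic LoRA branch that LoHa and LoKr generalize, delta_W = scale * up * down with scale = alpha / rank (or an explicit `.scale` tensor); the optional `lora_mid` tensor is the Tucker middle factor mentioned in the code comment and is omitted here:

```cpp
// Conceptual sketch (not the ggml code above): the classic LoRA branch merges
// a low-rank pair into a dense delta, delta_W = scale * up * down.
// Dense row-major 2-D case only; lora_mid (Tucker factor) is omitted.
#include <cstdio>
#include <vector>

static std::vector<float> matmul(const std::vector<float>& A, const std::vector<float>& B,
                                 int m, int k, int n) {  // (m x k) * (k x n), row-major
    std::vector<float> C(static_cast<size_t>(m) * n, 0.0f);
    for (int i = 0; i < m; i++)
        for (int p = 0; p < k; p++)
            for (int j = 0; j < n; j++)
                C[i * n + j] += A[i * k + p] * B[p * n + j];
    return C;
}

int main() {
    const int out = 2, in = 3, rank = 1;
    std::vector<float> up   = {1, 2};      // out x rank
    std::vector<float> down = {1, 0, -1};  // rank x in
    float alpha = 2.0f, scale = alpha / rank;

    std::vector<float> dW = matmul(up, down, out, rank, in);
    for (int r = 0; r < out; r++) {
        for (int c = 0; c < in; c++) printf("%6.1f", scale * dW[r * in + c]);
        printf("\n");
    }
    return 0;
}
```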