@@ -244,12 +244,15 @@ struct LoraModel : public GGMLRunner {
         std::vector<std::string> keys = to_lora_keys(k_tensor, version);
         if (keys.size() == 0)
             continue;
+
+        ggml_tensor* lora_mid  = NULL;  // tau for tucker decomposition
         ggml_tensor* lora_up   = NULL;
         ggml_tensor* lora_down = NULL;
         for (auto& key : keys) {
             std::string alpha_name         = "";
             std::string scale_name         = "";
             std::string split_q_scale_name = "";
+            std::string lora_mid_name      = "";
             std::string lora_down_name     = "";
             std::string lora_up_name       = "";

@@ -584,8 +587,10 @@ struct LoraModel : public GGMLRunner {
             }

             lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
-            alpha_name = lora_pre[type] + key + ".alpha";
-            scale_name = lora_pre[type] + key + ".scale";
+            lora_mid_name  = lora_pre[type] + key + ".lora_mid.weight";
+
+            alpha_name     = lora_pre[type] + key + ".alpha";
+            scale_name     = lora_pre[type] + key + ".scale";

             if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
                 lora_up = lora_tensors[lora_up_name];
@@ -594,6 +599,12 @@ struct LoraModel : public GGMLRunner {
             if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
                 lora_down = lora_tensors[lora_down_name];
             }
+
+            if (lora_tensors.find(lora_mid_name) != lora_tensors.end()) {
+                lora_mid = lora_tensors[lora_mid_name];
+                applied_lora_tensors.insert(lora_mid_name);
+            }
+
             applied_lora_tensors.insert(lora_up_name);
             applied_lora_tensors.insert(lora_down_name);
             applied_lora_tensors.insert(alpha_name);
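(Not part of the patch.) The key construction and lookup above come down to this: every LoRA entry carries up/down weights, while the ".lora_mid.weight" core is optional and only present for Tucker-factored conv layers; the extra applied_lora_tensors.insert(lora_mid_name) presumably keeps the mid tensor from being counted as unused later. A minimal standalone sketch of that pattern follows, with the container type, prefix, and key names being purely illustrative assumptions:

    #include <cstdio>
    #include <map>
    #include <string>

    struct ggml_tensor;  // opaque stand-in; the real definition lives in ggml.h

    int main() {
        // Hypothetical tensor map loaded from a LoRA file; a Tucker-style conv
        // LoRA would carry a "<prefix><key>.lora_mid.weight" entry next to the
        // usual lora_up / lora_down pair (names here are made up for the demo).
        std::map<std::string, ggml_tensor*> lora_tensors = {
            {"lora_unet_example.lora_up.weight", nullptr},
            {"lora_unet_example.lora_down.weight", nullptr},
            // {"lora_unet_example.lora_mid.weight", nullptr},  // only for Tucker conv LoRAs
        };

        std::string lora_mid_name = "lora_unet_example.lora_mid.weight";
        ggml_tensor* lora_mid     = nullptr;

        // Optional lookup: a miss simply leaves lora_mid NULL, and the plain
        // up * down path is taken when the delta weight is built.
        auto it = lora_tensors.find(lora_mid_name);
        if (it != lora_tensors.end()) {
            lora_mid = it->second;
        }
        std::printf("tucker path: %s\n", lora_mid ? "yes" : "no (plain up*down)");
        return 0;
    }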
@@ -625,9 +636,20 @@ struct LoraModel : public GGMLRunner {

             // ggml_mul_mat requires tensor b transposed
             lora_down = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, lora_down));
-            struct ggml_tensor* updown = ggml_mul_mat(compute_ctx, lora_up, lora_down);
-            updown = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, updown));
-            updown = ggml_reshape(compute_ctx, updown, weight);
+            struct ggml_tensor* updown = NULL;
+            if (lora_mid == NULL) {
+                updown = ggml_mul_mat(compute_ctx, lora_up, lora_down);
+                updown = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, updown));
+            } else {
+                // undoing tucker decomposition for conv layers.
+                // lora_mid has shape (3, 3, Rank, Rank)
+                // lora_down has shape (Rank, In, 1, 1)
+                // lora_up has shape (Rank, Out, 1, 1)
+                // conv layer shape is (3, 3, Out, In)
+                updown = ggml_mul_n_mode(compute_ctx, ggml_mul_n_mode(compute_ctx, lora_mid, lora_down, 3), lora_up, 2);
+                updown = ggml_cont(compute_ctx, updown);
+            }
+            updown = ggml_reshape(compute_ctx, updown, weight);
             GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
             updown = ggml_scale_inplace(compute_ctx, updown, scale_value);
             ggml_tensor* final_weight;
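(Also not part of the patch.) As a reading aid for the Tucker branch: the two nested ggml_mul_n_mode calls contract the two rank axes of the core (lora_mid) against lora_down and lora_up, producing a full conv-shaped delta weight. A rough CPU reference of that contraction is sketched below, assuming plain row-major float arrays; the pairing of rank axes with lora_up versus lora_down is inferred from the mode-2/mode-3 arguments and shape comments in the hunk and should be treated as an assumption, not a statement of ggml_mul_n_mode's exact layout:

    #include <vector>

    // Reference sketch: rebuild the conv-weight update
    //   delta[o][i][y][x] = sum_r sum_s up[o][r] * down[i][s] * mid[r][s][y][x]
    // where mid is the (Rank x Rank x KH x KW) core ("tau"), up is (Out x Rank)
    // and down is (In x Rank). Axis ordering here is a plain C convention and
    // does not mirror ggml's internal ne[] layout.
    std::vector<float> tucker_reconstruct(const std::vector<float>& mid,
                                          const std::vector<float>& down,
                                          const std::vector<float>& up,
                                          int rank, int in_ch, int out_ch,
                                          int kh, int kw) {
        std::vector<float> delta((size_t)out_ch * in_ch * kh * kw, 0.0f);
        for (int o = 0; o < out_ch; o++)
            for (int i = 0; i < in_ch; i++)
                for (int y = 0; y < kh; y++)
                    for (int x = 0; x < kw; x++) {
                        float acc = 0.0f;
                        for (int r = 0; r < rank; r++)
                            for (int s = 0; s < rank; s++)
                                acc += up[o * rank + r] *
                                       down[i * rank + s] *
                                       mid[((r * rank + s) * kh + y) * kw + x];
                        delta[((o * in_ch + i) * kh + y) * kw + x] = acc;
                    }
        return delta;
    }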