|
38 | 38 | C3AConfig, |
39 | 39 | DeloraConfig, |
40 | 40 | FourierFTConfig, |
| 41 | + GraloraConfig, |
41 | 42 | HRAConfig, |
42 | 43 | IA3Config, |
43 | 44 | LNTuningConfig, |
|
666 | 667 | "init_weights": True, |
667 | 668 | }, |
668 | 669 | ), |
| 670 | + ########### |
| 671 | + # GraLoRA # |
| 672 | + ########### |
| 673 | + ("Vanilla MLP 1 GraLoRA", "MLP", GraloraConfig, {"target_modules": "lin0"}), |
| 674 | + ("Vanilla MLP 2 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin0"]}), |
| 675 | + ("Vanilla MLP 3 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin1"]}), |
| 676 | + ("Vanilla MLP 4 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin0", "lin1"]}), |
| 677 | + ( |
| 678 | + "Vanilla MLP 5 GraLoRA", |
| 679 | + "MLP", |
| 680 | + GraloraConfig, |
| 681 | + {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}, |
| 682 | + ), |
| 683 | + ( |
| 684 | + "Embedding + transformers Conv1D 1 GraLoRA", |
| 685 | + "EmbConv1D", |
| 686 | + GraloraConfig, |
| 687 | + {"target_modules": ["conv1d"], "gralora_k": 1}, |
| 688 | + ), |
669 | 689 | ########## |
670 | 690 | # VBLoRA # |
671 | 691 | ########## |
|
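The new GraLoRA entries above follow the same pattern as the other custom-model test cases: each tuple names a test, picks a toy architecture, and passes config kwargs. A minimal sketch of what one such case exercises, assuming GraloraConfig is exposed from peft alongside the other configs imported above and works with the standard get_peft_model flow (the toy MLP and the lin0/lin1 names mirror the test fixtures, not the library's own code):

```python
# Minimal sketch, not taken from the test file: assumes GraloraConfig is importable
# from peft and plugs into the standard get_peft_model entry point.
import torch.nn as nn
from peft import GraloraConfig, get_peft_model


class MLP(nn.Module):
    """Toy module mirroring the "Vanilla MLP" fixture used by these test cases."""

    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.lin1(self.relu(self.lin0(x)))


config = GraloraConfig(target_modules=["lin0", "lin1"])  # as in "Vanilla MLP 4 GraLoRA"
model = get_peft_model(MLP(), config)
model.print_trainable_parameters()
```

The "Embedding + transformers Conv1D" case additionally sets gralora_k=1; its exact semantics are not shown in this hunk, so the value is only echoed here, not explained.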
979 | 999 | {"n_frequency": 10, "target_modules": ["lin0"]}, |
980 | 1000 | {"n_frequency": 10, "target_modules": ["lin1"]}, |
981 | 1001 | ), |
| 1002 | + ( |
| 1003 | + "GraLoRA Same", |
| 1004 | + "gralora", |
| 1005 | + GraloraConfig, |
| 1006 | + {"target_modules": ["lin0"], "init_weights": False}, |
| 1007 | + {"target_modules": ["lin0"], "init_weights": False}, |
| 1008 | + ), |
| 1009 | + ( |
| 1010 | + "GraLoRA Different", |
| 1011 | + "gralora", |
| 1012 | + GraloraConfig, |
| 1013 | + {"target_modules": ["lin0"], "init_weights": False}, |
| 1014 | + {"target_modules": ["lin1"], "init_weights": False}, |
| 1015 | + ), |
982 | 1016 | ( |
983 | 1017 | "SHiRA Same", |
984 | 1018 | "shira", |
|
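The "GraLoRA Same" / "GraLoRA Different" entries feed the multi-adapter tests with two configs that target either the same or different modules; init_weights=False keeps the adapter deltas non-zero so the two adapters remain distinguishable. A rough sketch of that pattern, reusing the toy MLP from the earlier sketch (the adapter names are illustrative; the test harness generates its own):

```python
# Sketch only: adapter names are illustrative and not taken from the test file.
from peft import GraloraConfig, get_peft_model

cfg_same = GraloraConfig(target_modules=["lin0"], init_weights=False)
cfg_diff = GraloraConfig(target_modules=["lin1"], init_weights=False)  # "GraLoRA Different"

model = get_peft_model(MLP(), cfg_same, adapter_name="adapter0")
model.add_adapter("adapter1", cfg_diff)  # second adapter on a different target module
model.set_adapter("adapter1")            # switch the active adapter
```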
1165 | 1199 | VeraConfig: "vera_lambda_", |
1166 | 1200 | RandLoraConfig: "randlora_", |
1167 | 1201 | FourierFTConfig: "fourierft_", |
| 1202 | + GraloraConfig: "gralora_", |
1168 | 1203 | C3AConfig: "c3a_", |
1169 | 1204 | HRAConfig: "hra_", |
1170 | 1205 | ShiraConfig: "shira_", |
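Registering GraloraConfig with the "gralora_" prefix tells the shared test helpers which parameter-name prefix identifies the weights this tuner injects. A small sketch of the kind of check that prefix enables, continuing from the earlier sketch (the exact parameter names are an assumption; only the prefix appears in this hunk):

```python
# Sketch only: assumes the injected GraLoRA parameters carry "gralora_" in their names.
gralora_param_names = [n for n, p in model.named_parameters() if "gralora_" in n]
assert gralora_param_names, "expected GraLoRA parameters to be injected"
```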
@@ -3089,12 +3124,12 @@ def test_add_weighted_adapter_subtraction_with_negative_weights(self): |
3089 | 3124 | cancelled_B = module.lora_B["cancelled"].weight.data |
3090 | 3125 |
|
3091 | 3126 | # The weights should be approximately zero (they cancel out) |
3092 | | - assert torch.allclose(cancelled_A, torch.zeros_like(cancelled_A), atol=1e-5), ( |
3093 | | - f"Cancelled A should be ~0, got max abs value {cancelled_A.abs().max()}" |
3094 | | - ) |
3095 | | - assert torch.allclose(cancelled_B, torch.zeros_like(cancelled_B), atol=1e-5), ( |
3096 | | - f"Cancelled B should be ~0, got max abs value {cancelled_B.abs().max()}" |
3097 | | - ) |
| 3127 | + assert torch.allclose( |
| 3128 | + cancelled_A, torch.zeros_like(cancelled_A), atol=1e-5 |
| 3129 | + ), f"Cancelled A should be ~0, got max abs value {cancelled_A.abs().max()}" |
| 3130 | + assert torch.allclose( |
| 3131 | + cancelled_B, torch.zeros_like(cancelled_B), atol=1e-5 |
| 3132 | + ), f"Cancelled B should be ~0, got max abs value {cancelled_B.abs().max()}" |
3098 | 3133 |
|
3099 | 3134 | def test_add_weighted_adapter_negative_weight_with_different_scaling(self): |
3100 | 3135 | # Test negative weights with different scaling factors (lora_alpha) |
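This hunk only reformats the assertions of a test in which two LoRA adapter contributions with opposite-sign weights are merged into a "cancelled" adapter that should come out numerically zero. A hedged sketch of how such a combination might be driven on a LoRA-adapted model; the adapter names, weights, and combination_type below are assumptions, since the hunk shows only the assertions:

```python
# Sketch only: argument values are assumptions; the diff shows just the checks.
# Assume "adapter0" and "adapter1" hold identical LoRA weights on the same modules.
model.add_weighted_adapter(
    adapters=["adapter0", "adapter1"],
    weights=[1.0, -1.0],          # opposite signs, so the two contributions cancel
    adapter_name="cancelled",
    combination_type="linear",
)
for module in model.modules():
    if hasattr(module, "lora_A") and "cancelled" in module.lora_A:
        cancelled_A = module.lora_A["cancelled"].weight.data
        assert cancelled_A.abs().max() < 1e-5  # approximately zero
```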
@@ -3500,9 +3535,9 @@ def test_multirank_2(self): |
3500 | 3535 | if isinstance(module, BaseTunerLayer): |
3501 | 3536 | rank_expected = rank_pattern.get(key, r) |
3502 | 3537 | rank_current = module.lora_A[adapter].weight.shape[0] |
3503 | | - assert rank_current == rank_expected, ( |
3504 | | - f"Rank {rank_current} is not equal to expected {rank_expected}" |
3505 | | - ) |
| 3538 | + assert ( |
| 3539 | + rank_current == rank_expected |
| 3540 | + ), f"Rank {rank_current} is not equal to expected {rank_expected}" |
3506 | 3541 |
|
3507 | 3542 |
|
3508 | 3543 | class TestLayerRepr: |
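test_multirank_2 checks that rank_pattern overrides the default rank r per module, reading the effective rank off the first dimension of lora_A's weight. A brief sketch of that relationship, reusing the toy MLP from the first sketch ("default" is PEFT's default adapter name):

```python
# Sketch only: rank_pattern maps module-name patterns to per-module ranks;
# lora_A's weight has shape (rank, in_features), so shape[0] is the rank in use.
from peft import LoraConfig, get_peft_model

config = LoraConfig(r=8, rank_pattern={"lin0": 4}, target_modules=["lin0", "lin1"])
peft_model = get_peft_model(MLP(), config)

for name, module in peft_model.named_modules():
    if hasattr(module, "lora_A") and "default" in module.lora_A:
        print(name, module.lora_A["default"].weight.shape[0])  # lin0 -> 4, lin1 -> 8
```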
|