Skip to content

Commit dec25f5

Browse files
Update test code for the GraLoRA method
1 parent 9431502 commit dec25f5

File tree

2 files changed

+50
-46
lines changed

2 files changed

+50
-46
lines changed

tests/test_custom_models.py

Lines changed: 44 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@
3838
C3AConfig,
3939
DeloraConfig,
4040
FourierFTConfig,
41+
GraloraConfig,
4142
HRAConfig,
4243
IA3Config,
4344
LNTuningConfig,
@@ -666,6 +667,25 @@
666667
"init_weights": True,
667668
},
668669
),
670+
###########
671+
# GraLoRA #
672+
###########
673+
("Vanilla MLP 1 GraLoRA", "MLP", GraloraConfig, {"target_modules": "lin0"}),
674+
("Vanilla MLP 2 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin0"]}),
675+
("Vanilla MLP 3 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin1"]}),
676+
("Vanilla MLP 4 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin0", "lin1"]}),
677+
(
678+
"Vanilla MLP 5 GraLoRA",
679+
"MLP",
680+
GraloraConfig,
681+
{"target_modules": ["lin0"], "modules_to_save": ["lin1"]},
682+
),
683+
(
684+
"Embedding + transformers Conv1D 1 GraLoRA",
685+
"EmbConv1D",
686+
GraloraConfig,
687+
{"target_modules": ["conv1d"], "gralora_k": 1},
688+
),
669689
##########
670690
# VBLoRA #
671691
##########
@@ -979,6 +999,20 @@
979999
{"n_frequency": 10, "target_modules": ["lin0"]},
9801000
{"n_frequency": 10, "target_modules": ["lin1"]},
9811001
),
1002+
(
1003+
"GraLoRA Same",
1004+
"gralora",
1005+
GraloraConfig,
1006+
{"target_modules": ["lin0"], "init_weights": False},
1007+
{"target_modules": ["lin0"], "init_weights": False},
1008+
),
1009+
(
1010+
"GraLoRA Different",
1011+
"gralora",
1012+
GraloraConfig,
1013+
{"target_modules": ["lin0"], "init_weights": False},
1014+
{"target_modules": ["lin1"], "init_weights": False},
1015+
),
9821016
(
9831017
"SHiRA Same",
9841018
"shira",
@@ -1165,6 +1199,7 @@
11651199
VeraConfig: "vera_lambda_",
11661200
RandLoraConfig: "randlora_",
11671201
FourierFTConfig: "fourierft_",
1202+
GraloraConfig: "gralora_",
11681203
C3AConfig: "c3a_",
11691204
HRAConfig: "hra_",
11701205
ShiraConfig: "shira_",
@@ -3089,12 +3124,12 @@ def test_add_weighted_adapter_subtraction_with_negative_weights(self):
30893124
cancelled_B = module.lora_B["cancelled"].weight.data
30903125

30913126
# The weights should be approximately zero (they cancel out)
3092-
assert torch.allclose(cancelled_A, torch.zeros_like(cancelled_A), atol=1e-5), (
3093-
f"Cancelled A should be ~0, got max abs value {cancelled_A.abs().max()}"
3094-
)
3095-
assert torch.allclose(cancelled_B, torch.zeros_like(cancelled_B), atol=1e-5), (
3096-
f"Cancelled B should be ~0, got max abs value {cancelled_B.abs().max()}"
3097-
)
3127+
assert torch.allclose(
3128+
cancelled_A, torch.zeros_like(cancelled_A), atol=1e-5
3129+
), f"Cancelled A should be ~0, got max abs value {cancelled_A.abs().max()}"
3130+
assert torch.allclose(
3131+
cancelled_B, torch.zeros_like(cancelled_B), atol=1e-5
3132+
), f"Cancelled B should be ~0, got max abs value {cancelled_B.abs().max()}"
30983133

30993134
def test_add_weighted_adapter_negative_weight_with_different_scaling(self):
31003135
# Test negative weights with different scaling factors (lora_alpha)
@@ -3500,9 +3535,9 @@ def test_multirank_2(self):
35003535
if isinstance(module, BaseTunerLayer):
35013536
rank_expected = rank_pattern.get(key, r)
35023537
rank_current = module.lora_A[adapter].weight.shape[0]
3503-
assert rank_current == rank_expected, (
3504-
f"Rank {rank_current} is not equal to expected {rank_expected}"
3505-
)
3538+
assert (
3539+
rank_current == rank_expected
3540+
), f"Rank {rank_current} is not equal to expected {rank_expected}"
35063541

35073542

35083543
class TestLayerRepr:

tests/test_gralora.py

Lines changed: 6 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def test_gralora_parameter_shapes(self, mlp_gralora_hybrid):
112112
in_features = module.in_features
113113
out_features = module.out_features
114114
k = 4
115-
gralora_rank = 16 - 4 # r - hybrid_r
115+
gralora_rank = 16
116116

117117
# Check GraLoRA block shapes
118118
# Each block has full gralora_rank, not gralora_rank // k
@@ -203,7 +203,7 @@ def test_gralora_pure_vs_hybrid_params(self):
203203
mlp_hybrid = MLP()
204204
config_hybrid = GraloraConfig(
205205
target_modules=["lin1", "lin2"],
206-
r=16,
206+
r=12,
207207
gralora_k=4,
208208
hybrid_r=4,
209209
)
@@ -217,9 +217,9 @@ def count_trainable_params(model):
217217

218218
# Pure and hybrid should have same total parameters (r is constant)
219219
# but distributed differently between block-diagonal and full-rank components
220-
assert params_pure == params_hybrid, (
221-
f"Pure ({params_pure}) and Hybrid ({params_hybrid}) should have same parameter count"
222-
)
220+
assert (
221+
params_pure == params_hybrid
222+
), f"Pure ({params_pure}) and Hybrid ({params_hybrid}) should have same parameter count"
223223

224224
# Check that hybrid has general components
225225
has_general = False
@@ -444,7 +444,7 @@ def test_gralora_rank_divisibility_check(self):
444444
hybrid_r=0,
445445
)
446446

447-
with pytest.raises(AssertionError, match="r should be divisible by gralora_k"):
447+
with pytest.raises(ValueError, match="r should be divisible by gralora_k"):
448448
get_peft_model(mlp, config)
449449

450450
def test_gralora_trainable_parameters_only(self, mlp_gralora_hybrid):
@@ -827,37 +827,6 @@ def test_gralora_unload_without_merge(self):
827827
# Should match base model output (no merge)
828828
assert torch.allclose(base_output, unloaded_output, atol=1e-5)
829829

830-
def test_gralora_get_peft_config_as_dict(self):
831-
"""Test get_peft_config_as_dict method"""
832-
torch.manual_seed(0)
833-
mlp = MLP()
834-
config = GraloraConfig(
835-
target_modules=["lin1"],
836-
r=8,
837-
gralora_k=2,
838-
hybrid_r=4,
839-
gralora_alpha=16,
840-
)
841-
model = get_peft_model(mlp, config)
842-
843-
config_dict = model.get_peft_config_as_dict(inference=False)
844-
845-
assert "default" in config_dict
846-
assert config_dict["default"]["r"] == 8
847-
assert config_dict["default"]["gralora_k"] == 2
848-
assert config_dict["default"]["hybrid_r"] == 4
849-
850-
def test_gralora_get_peft_config_as_dict_inference_mode(self):
851-
"""Test get_peft_config_as_dict with inference=True"""
852-
torch.manual_seed(0)
853-
mlp = MLP()
854-
config = GraloraConfig(target_modules=["lin1"], r=8, gralora_k=2)
855-
model = get_peft_model(mlp, config)
856-
857-
config_dict = model.get_peft_config_as_dict(inference=True)
858-
859-
assert config_dict["default"]["inference_mode"] is True
860-
861830
def test_gralora_merge_with_hybrid_component(self):
862831
"""Test that merge works correctly with hybrid component"""
863832
torch.manual_seed(0)

0 commit comments

Comments (0)