Commit 9e975d3 (parent: a21648d)

WIP: janice network issues

Signed-off-by: Kyle Sayers <[email protected]>

File tree: 2 files changed (+10, -3 lines)


src/llmcompressor/modifiers/transform/quip/base.py (1 addition, 0 deletions)

@@ -45,6 +45,7 @@ class QuIPModifier(Modifier):
     )
     randomize: bool = Field(default=False, exclude=True)
     learnable: bool = Field(default=False, exclude=True)
+    precision:
     ignore: Union[str, List[str]] = Field(default="lm_head", exclude=True)

     # optional override for more fine-grained control
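
Note that the added `precision:` annotation is committed incomplete, consistent with the WIP commit message. For reference only, a completed declaration in the style of the neighboring fields might look like the sketch below; the `torch.dtype` annotation and the `torch.float64` default are assumptions, not part of this commit.

    # Hypothetical completion of the WIP line (type and default are assumed,
    # written in the same Pydantic Field style as the surrounding attributes):
    precision: torch.dtype = Field(default=torch.float64, exclude=True)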

tests/llmcompressor/modifiers/transform/test_correctness.py (9 additions, 3 deletions)

@@ -1,3 +1,4 @@
+import os
 import pytest
 import torch
 from transformers import AutoModelForCausalLM
@@ -8,16 +9,20 @@


 @requires_gpu
+# @pytest.mark.skipif(
+#     (not os.getenv("HF_TOKEN")),
+#     reason="Skipping tracing tests requiring gated model access",
+# )
 @pytest.mark.parametrize(
     "dtype,exp_mse",
     [
-        (torch.bfloat16, 1e-2),
-        (torch.float32, 1e-9),
+        (torch.bfloat16, 5e-3),
+        (torch.float32, 5e-11),
     ],
 )
 def test_apply_correctness(dtype, exp_mse):
     model = AutoModelForCausalLM.from_pretrained(
-        "meta-llama/Meta-Llama-3-8B-Instruct", device_map="cuda", torch_dtype=dtype
+        "meta-llama/Llama-3.2-1B-Instruct", device_map="cuda", torch_dtype=dtype
     )
     state = State(model=model)
     modifier = QuIPModifier(transform_type="random-hadamard")
@@ -32,4 +37,5 @@ def test_apply_correctness(dtype, exp_mse):
     with torch.no_grad():
         output = model(**input)

+    print(torch.nn.MSELoss()(output.logits, true_output.logits))
     assert torch.nn.MSELoss()(output.logits, true_output.logits) <= exp_mse
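
The `skipif` marker added above is committed commented out. If re-enabled, it would skip the test when no Hugging Face token is available, since the Meta Llama checkpoints it downloads are gated. A minimal, self-contained sketch of that gating pattern, assuming pytest and an `HF_TOKEN` environment variable (the test name below is hypothetical):

    import os

    import pytest

    # Skip when no Hugging Face token is present in the environment,
    # since gated checkpoints cannot be downloaded without one.
    @pytest.mark.skipif(
        not os.getenv("HF_TOKEN"),
        reason="Skipping tests requiring gated model access",
    )
    def test_requires_gated_model():
        ...  # body elided; the real test loads a gated model from the Hub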
