Skip to content

Commit 006899c

Browse files
authored
[Tests] Fix test case; update structure (#1375)
# Summary - Fix device_map and set torch.dtype for the given model - Move tests to a folder which makes more sense
1 parent 90c4075 commit 006899c

File tree

4 files changed

+9
-4
lines changed

4 files changed

+9
-4
lines changed

tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ def _test_oneshot_and_finetune(self):
6464
)
6565
# model is first sparsified, then finetuned, both should have the same sparsity
6666
assert config_sparse_applied["global_sparsity"] == pytest.approx(
67-
config_finetune_applied["global_sparsity"],
68-
abs=1e-5
67+
config_finetune_applied["global_sparsity"], abs=1e-5
6968
)
7069

7170
def tearDown(self):

tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune_with_tokenizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_oneshot_and_finetune_with_tokenizer(self):
3636
self.model,
3737
)
3838
model_loaded = AutoModelForCausalLM.from_pretrained(
39-
self.model, device_map="auto"
39+
self.model, device_map="cuda:0", torch_dtype="auto"
4040
)
4141

4242
dataset_loaded = load_dataset(
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
cadence: "nightly"
2+
test_type: "regression"
3+
model: "nm-testing/llama2.c-stories15M"
4+
dataset: open_platypus

tests/llmcompressor/transformers/finetune/test_finetune_oneshot_with_modifier.py renamed to tests/llmcompressor/transformers/obcq/test_oneshot_with_modifier.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
from tests.testing_utils import parse_params, requires_gpu
1010

11-
CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/finetune/finetune_generic"
11+
CONFIGS_DIRECTORY = (
12+
"tests/llmcompressor/transformers/obcq/obcq_configs/sparsity_generic"
13+
)
1214

1315

1416
@pytest.mark.integration

0 commit comments

Comments
 (0)