
Commit 7359e1b

Refactor to standardize pytest usage (#1844)
1 parent 9bcacc5 commit 7359e1b

36 files changed: +1751 −2265 lines

setup.py

Lines changed: 0 additions & 1 deletion

@@ -147,7 +147,6 @@ def localversion_func(version: ScmVersion) -> str:
     "pytest>=6.0.0",
     "pytest-mock>=3.6.0",
     "pytest-rerunfailures>=13.0",
-    "parameterized",
     "lm_eval==0.4.5",
     # test dependencies
     "beautifulsoup4~=4.12.3",

tests/custom_test.py

Lines changed: 0 additions & 71 deletions
This file was deleted.

tests/data.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

tests/examples/utils.py

Lines changed: 3 additions & 10 deletions

@@ -10,21 +10,14 @@
 from bs4 import BeautifulSoup, ResultSet, Tag
 from cmarkgfm import github_flavored_markdown_to_html as gfm_to_html
 
-from tests.testing_utils import run_cli_command
+from tests.testing_utils import requires_gpu, run_cli_command
 
 _T = TypeVar("_T")
 
 
 def requires_gpu_count(num_required_gpus: int) -> pytest.MarkDecorator:
-    """
-    Pytest decorator to skip based on number of available GPUs. This plays nicely with
-    the CUDA_VISIBLE_DEVICES environment variable.
-    """
-    import torch
-
-    num_gpus = torch.cuda.device_count()
-    reason = f"{num_required_gpus} GPUs required, {num_gpus} GPUs detected"
-    return pytest.mark.skipif(num_required_gpus > num_gpus, reason=reason)
+    # Remove after #1801
+    return requires_gpu(num_required_gpus)
 
 
 def requires_gpu_mem(required_amount: Union[int, float]) -> pytest.MarkDecorator:
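The deleted body above shows the check that requires_gpu now centralizes in tests.testing_utils. As a standalone sketch of that pattern, assuming the helper mirrors the removed logic rather than quoting the repository's actual implementation:

```python
import pytest
import torch


def requires_gpu(num_required_gpus: int = 1) -> pytest.MarkDecorator:
    # Skip when fewer GPUs are visible than required; this respects
    # CUDA_VISIBLE_DEVICES because torch.cuda.device_count() does.
    num_gpus = torch.cuda.device_count()
    reason = f"{num_required_gpus} GPUs required, {num_gpus} GPUs detected"
    return pytest.mark.skipif(num_required_gpus > num_gpus, reason=reason)
```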

tests/llmcompressor/conftest.py

Lines changed: 8 additions & 0 deletions

@@ -5,6 +5,8 @@
 
 import pytest
 
+from llmcompressor.modifiers.factory import ModifierFactory
+
 try:
     import wandb
 except Exception:
@@ -15,6 +17,12 @@
     os.environ["NM_TEST_LOG_DIR"] = "nm_temp_test_logs"
 
 
+@pytest.fixture
+def setup_modifier_factory():
+    ModifierFactory.refresh()
+    assert ModifierFactory._loaded, "ModifierFactory not loaded"
+
+
 def _get_files(directory: str) -> List[str]:
     list_filepaths = []
     for root, dirs, files in os.walk(directory):
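Because the fixture returns no value, tests under tests/llmcompressor/ can opt in either with the usefixtures marker, as the converted tests below do, or by naming it as a parameter. A minimal illustration with hypothetical test names:

```python
import pytest

from llmcompressor.modifiers.factory import ModifierFactory


@pytest.mark.usefixtures("setup_modifier_factory")
def test_factory_loaded_via_marker():
    # The fixture has already run ModifierFactory.refresh().
    assert ModifierFactory._loaded


def test_factory_loaded_via_argument(setup_modifier_factory):
    # Equivalent: requesting the conftest fixture by name.
    assert ModifierFactory._loaded
```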

tests/llmcompressor/modifiers/awq/test_base.py

Lines changed: 1 addition & 4 deletions

@@ -6,15 +6,12 @@
 from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
 from llmcompressor.modifiers.awq.base import get_lowest_common_parent
 from llmcompressor.modifiers.factory import ModifierFactory
-from tests.llmcompressor.modifiers.conf import setup_modifier_factory
 
 
 @pytest.mark.unit
+@pytest.mark.usefixtures("setup_modifier_factory")
 def test_awq_is_registered():
     """Ensure AWQModifier is registered in ModifierFactory"""
-
-    setup_modifier_factory()
-
     modifier = ModifierFactory.create(
         type_="AWQModifier",
         allow_experimental=False,

tests/llmcompressor/modifiers/conf.py

Lines changed: 0 additions & 6 deletions

@@ -3,12 +3,6 @@
 from torch.utils.data import DataLoader
 
 from llmcompressor.core import Event, EventType, State
-from llmcompressor.modifiers.factory import ModifierFactory
-
-
-def setup_modifier_factory():
-    ModifierFactory.refresh()
-    assert ModifierFactory._loaded, "ModifierFactory not loaded"
 
 
 class LifecyleTestingHarness:
Lines changed: 17 additions & 27 deletions

@@ -1,38 +1,28 @@
-import unittest
-
 import pytest
 
 from llmcompressor.modifiers.factory import ModifierFactory
 from llmcompressor.modifiers.logarithmic_equalization.base import (
     LogarithmicEqualizationModifier,
 )
 from llmcompressor.modifiers.smoothquant.base import SmoothQuantModifier
-from tests.llmcompressor.modifiers.conf import setup_modifier_factory
 
 
 @pytest.mark.unit
-class TestLogarithmicEqualizationIsRegistered(unittest.TestCase):
-    def setUp(self):
-        self.kwargs = dict(
-            smoothing_strength=0.3,
-            mappings=[(["layer1", "layer2"], "layer3")],
-        )
-        setup_modifier_factory()
-
-    def test_log_equalization_is_registered(self):
-        modifier = ModifierFactory.create(
-            type_="LogarithmicEqualizationModifier",
-            allow_experimental=False,
-            allow_registered=True,
-            **self.kwargs,
-        )
-
-        self.assertIsInstance(
-            modifier,
-            LogarithmicEqualizationModifier,
-            "PyTorch LogarithmicEqualizationModifier not registered",
-        )
+@pytest.mark.usefixtures("setup_modifier_factory")
+def test_logarithmic_equalization_is_registered():
+    smoothing_strength = 0.3
+    mappings = [(["layer1", "layer2"], "layer3")]
+    modifier = ModifierFactory.create(
+        type_="LogarithmicEqualizationModifier",
+        allow_experimental=False,
+        allow_registered=True,
+        smoothing_strength=smoothing_strength,
+        mappings=mappings,
+    )
 
-        self.assertIsInstance(modifier, SmoothQuantModifier)
-        self.assertEqual(modifier.smoothing_strength, self.kwargs["smoothing_strength"])
-        self.assertEqual(modifier.mappings, self.kwargs["mappings"])
+    assert isinstance(
+        modifier, LogarithmicEqualizationModifier
+    ), "PyTorch LogarithmicEqualizationModifier not registered"
+    assert isinstance(modifier, SmoothQuantModifier)
+    assert modifier.smoothing_strength == smoothing_strength
+    assert modifier.mappings == mappings
Lines changed: 14 additions & 23 deletions

@@ -1,31 +1,22 @@
-import unittest
-
 import pytest
 
 from llmcompressor.modifiers.factory import ModifierFactory
 from llmcompressor.modifiers.obcq.base import SparseGPTModifier
-from tests.llmcompressor.modifiers.conf import setup_modifier_factory
 
 
 @pytest.mark.unit
-class TestSparseGPTIsRegistered(unittest.TestCase):
-    def setUp(self):
-        self.kwargs = dict(
-            sparsity=0.5,
-            targets="__ALL_PRUNABLE__",
-        )
-        setup_modifier_factory()
-
-    def test_wanda_is_registered(self):
-        type_ = ModifierFactory.create(
-            type_="SparseGPTModifier",
-            allow_experimental=False,
-            allow_registered=True,
-            **self.kwargs,
-        )
+@pytest.mark.usefixtures("setup_modifier_factory")
+def test_sparse_gpt_is_registered():
+    sparsity = 0.5
+    targets = "__ALL_PRUNABLE__"
+    type_ = ModifierFactory.create(
+        type_="SparseGPTModifier",
+        allow_experimental=False,
+        allow_registered=True,
+        sparsity=sparsity,
+        targets=targets,
+    )
 
-        self.assertIsInstance(
-            type_,
-            SparseGPTModifier,
-            "PyTorch SparseGPTModifier not registered",
-        )
+    assert isinstance(
+        type_, SparseGPTModifier
+    ), "PyTorch SparseGPTModifier not registered"
Lines changed: 14 additions & 23 deletions

@@ -1,31 +1,22 @@
-import unittest
-
 import pytest
 
 from llmcompressor.modifiers.factory import ModifierFactory
 from llmcompressor.modifiers.pruning.wanda.base import WandaPruningModifier
-from tests.llmcompressor.modifiers.conf import setup_modifier_factory
 
 
 @pytest.mark.unit
-class TestWandaIsRegistered(unittest.TestCase):
-    def setUp(self):
-        self.kwargs = dict(
-            sparsity=0.5,
-            targets="__ALL_PRUNABLE__",
-        )
-        setup_modifier_factory()
-
-    def test_wanda_is_registered(self):
-        type_ = ModifierFactory.create(
-            type_="WandaPruningModifier",
-            allow_experimental=False,
-            allow_registered=True,
-            **self.kwargs,
-        )
+@pytest.mark.usefixtures("setup_modifier_factory")
+def test_wanda_is_registered():
+    sparsity = 0.5
+    targets = "__ALL_PRUNABLE__"
+    type_ = ModifierFactory.create(
+        type_="WandaPruningModifier",
+        allow_experimental=False,
+        allow_registered=True,
+        sparsity=sparsity,
+        targets=targets,
+    )
 
-        self.assertIsInstance(
-            type_,
-            WandaPruningModifier,
-            "PyTorch WandaPruningModifier not registered",
-        )
+    assert isinstance(
+        type_, WandaPruningModifier
+    ), "PyTorch WandaPruningModifier not registered"
