-
Notifications
You must be signed in to change notification settings - Fork 457
Expand file tree
/
Copy pathtest_base.py
More file actions
28 lines (24 loc) · 994 Bytes
/
test_base.py
File metadata and controls
28 lines (24 loc) · 994 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import pytest
from llmcompressor.modifiers.factory import ModifierFactory
from llmcompressor.modifiers.logarithmic_equalization.base import (
LogarithmicEqualizationModifier,
)
from llmcompressor.modifiers.transform.smoothquant.base import SmoothQuantModifier
@pytest.mark.unit
@pytest.mark.usefixtures("setup_modifier_factory")
def test_logarithmic_equalization_is_registered():
    """Verify the factory can build a LogarithmicEqualizationModifier by name.

    The modifier is requested from ``ModifierFactory.create`` with
    ``allow_experimental=False`` and ``allow_registered=True``. The returned
    object must be a ``LogarithmicEqualizationModifier``, must also be a
    ``SmoothQuantModifier``, and must carry the ``smoothing_strength`` and
    ``mappings`` values passed to the factory.

    NOTE(review): relies on the ``setup_modifier_factory`` fixture defined
    elsewhere (presumably it populates the factory registry -- confirm).
    """
    expected_strength = 0.3
    expected_mappings = [(["layer1", "layer2"], "layer3")]

    factory_kwargs = dict(
        type_="LogarithmicEqualizationModifier",
        allow_experimental=False,
        allow_registered=True,
        smoothing_strength=expected_strength,
        mappings=expected_mappings,
    )
    created = ModifierFactory.create(**factory_kwargs)

    # The type name must resolve to the concrete registered class.
    assert isinstance(created, LogarithmicEqualizationModifier), (
        "PyTorch LogarithmicEqualizationModifier not registered"
    )
    # The log-equalization modifier is expected to build on SmoothQuant.
    assert isinstance(created, SmoothQuantModifier)

    # Arguments forwarded through the factory must be preserved on the instance.
    assert created.smoothing_strength == expected_strength
    assert created.mappings == expected_mappings