Skip to content

Commit 13b5c0b

Browse files
authored
Add sparsity structure enum (#197)
1 parent 07abbf3 commit 13b5c0b

File tree

3 files changed

+130
-2
lines changed

3 files changed

+130
-2
lines changed

src/compressed_tensors/config/base.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from enum import Enum
15+
from enum import Enum, unique
1616
from typing import List, Optional
1717

1818
from compressed_tensors.registry import RegistryMixin
1919
from pydantic import BaseModel
2020

2121

22-
__all__ = ["SparsityCompressionConfig", "CompressionFormat"]
22+
__all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"]
2323

2424

25+
@unique
2526
class CompressionFormat(Enum):
2627
dense = "dense"
2728
sparse_bitmask = "sparse-bitmask"
@@ -32,6 +33,63 @@ class CompressionFormat(Enum):
3233
marlin_24 = "marlin-24"
3334

3435

36+
@unique
37+
class SparsityStructure(Enum):
38+
"""
39+
An enumeration to represent different sparsity structures.
40+
41+
Attributes
42+
----------
43+
TWO_FOUR : str
44+
Represents a 2:4 sparsity structure.
45+
ZERO_ZERO : str
46+
Represents a 0:0 sparsity structure.
47+
UNSTRUCTURED : str
48+
Represents an unstructured sparsity structure.
49+
50+
Examples
51+
--------
52+
>>> SparsityStructure('2:4')
53+
<SparsityStructure.TWO_FOUR: '2:4'>
54+
55+
>>> SparsityStructure('unstructured')
56+
<SparsityStructure.UNSTRUCTURED: 'unstructured'>
57+
58+
>>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
59+
True
60+
61+
>>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
62+
True
63+
64+
>>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
65+
True
66+
67+
>>> SparsityStructure('invalid')
68+
Traceback (most recent call last):
69+
...
70+
ValueError: invalid is not a valid SparsityStructure
71+
"""
72+
73+
TWO_FOUR = "2:4"
74+
UNSTRUCTURED = "unstructured"
75+
ZERO_ZERO = "0:0"
76+
77+
def __new__(cls, value):
78+
obj = object.__new__(cls)
79+
obj._value_ = value.lower() if value is not None else value
80+
return obj
81+
82+
@classmethod
83+
def _missing_(cls, value):
84+
# Handle None and case-insensitive values
85+
if value is None:
86+
return cls.UNSTRUCTURED
87+
for member in cls:
88+
if member.value == value.lower():
89+
return member
90+
raise ValueError(f"{value} is not a valid {cls.__name__}")
91+
92+
3593
class SparsityCompressionConfig(RegistryMixin, BaseModel):
3694
"""
3795
Base data class for storing sparsity compression parameters

tests/test_configs/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

tests/test_configs/test_base.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from compressed_tensors.config import SparsityStructure
17+
18+
19+
def test_sparsity_structure_valid_cases():
20+
assert (
21+
SparsityStructure("2:4") == SparsityStructure.TWO_FOUR
22+
), "Failed to match '2:4' with TWO_FOUR"
23+
assert (
24+
SparsityStructure("unstructured") == SparsityStructure.UNSTRUCTURED
25+
), "Failed to match 'unstructured' with UNSTRUCTURED"
26+
assert (
27+
SparsityStructure("UNSTRUCTURED") == SparsityStructure.UNSTRUCTURED
28+
), "Failed to match 'UNSTRUCTURED' with UNSTRUCTURED"
29+
assert (
30+
SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
31+
), "Failed to match None with UNSTRUCTURED"
32+
33+
34+
def test_sparsity_structure_invalid_case():
35+
with pytest.raises(ValueError, match="invalid is not a valid SparsityStructure"):
36+
SparsityStructure("invalid")
37+
38+
39+
def test_sparsity_structure_case_insensitivity():
40+
assert (
41+
SparsityStructure("2:4") == SparsityStructure.TWO_FOUR
42+
), "Failed to match '2:4' with TWO_FOUR"
43+
assert (
44+
SparsityStructure("2:4".upper()) == SparsityStructure.TWO_FOUR
45+
), "Failed to match '2:4'.upper() with TWO_FOUR"
46+
assert (
47+
SparsityStructure("unstructured".upper()) == SparsityStructure.UNSTRUCTURED
48+
), "Failed to match 'unstructured'.upper() with UNSTRUCTURED"
49+
assert (
50+
SparsityStructure("UNSTRUCTURED".lower()) == SparsityStructure.UNSTRUCTURED
51+
), "Failed to match 'UNSTRUCTURED'.lower() with UNSTRUCTURED"
52+
53+
54+
def test_sparsity_structure_default_case():
55+
assert (
56+
SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
57+
), "Failed to match None with UNSTRUCTURED"

0 commit comments

Comments
 (0)