Skip to content

Commit 2d6d2f3

Browse files
committed
feat: add input validation in lossless source coding classes
1 parent e597f54 commit 2d6d2f3

File tree

8 files changed

+52
-3
lines changed

8 files changed

+52
-3
lines changed

src/komm/_lossless_coding/LempelZiv77Code.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy.typing as npt
55
from tqdm import tqdm
66

7+
from .._util.validators import validate_integer_range
78
from .util import integer_to_symbols, symbols_to_integer
89

910
Token = tuple[int, int, int]
@@ -100,7 +101,7 @@ def source_to_tokens(self, source: npt.ArrayLike) -> list[Token]:
100101
>>> lz77.source_to_tokens([0, 0, 1, 0, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 2])
101102
[(8, 2, 1), (7, 3, 2), (6, 7, 2)]
102103
"""
103-
source = np.asarray(source, dtype=int)
104+
source = validate_integer_range(source, high=self.source_cardinality)
104105
ss, ls = self.search_size, self.lookahead_size
105106
buffer = bytes(self.search_buffer + source.tolist())
106107
tokens: list[Token] = []

src/komm/_lossless_coding/LempelZiv78Code.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import numpy.typing as npt
66
from tqdm import tqdm
77

8+
from komm._util.validators import validate_integer_range
9+
810
from .util import Word, integer_to_symbols, symbols_to_integer
911

1012
Token = tuple[int, int]
@@ -46,7 +48,7 @@ def source_to_tokens(self, source: npt.ArrayLike) -> list[Token]:
4648
>>> lz78.source_to_tokens([1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0])
4749
[(0, 1), (0, 0), (1, 1), (2, 1), (4, 0), (2, 0)]
4850
"""
49-
source = np.asarray(source)
51+
source = validate_integer_range(source, high=self.source_cardinality)
5052
dictionary: dict[Word, int] = {(): 0}
5153
tokens: list[Token] = []
5254
word: Word = ()

src/komm/_lossless_coding/LempelZivWelchCode.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import numpy.typing as npt
66
from tqdm import tqdm
77

8+
from komm._util.validators import validate_integer_range
9+
810
from .util import Word, integer_to_symbols, symbols_to_integer
911

1012

@@ -54,7 +56,7 @@ def encode(self, input: npt.ArrayLike) -> npt.NDArray[np.integer]:
5456
array([0, 2, 3, 4, 5])
5557
"""
5658
calX, calY = self.source_cardinality, self.target_cardinality
57-
input = np.asarray(input)
59+
input = validate_integer_range(input, high=calX)
5860
dictionary: dict[Word, int] = {(s,): s for s in range(calX)}
5961
output: list[int] = []
6062

src/komm/_lossless_coding/VariableToFixedCode.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,11 @@ def encode(self, input: npt.ArrayLike) -> Array1D[np.integer]:
360360
>>> code.encode([1, 0, 0]) # Incomplete input, completed as 1|000
361361
array([1, 1, 0, 0])
362362
363+
>>> code.encode([0, 7, 0, 0]) # 07 is not a valid source word
364+
Traceback (most recent call last):
365+
...
366+
ValueError: input contains invalid word
367+
363368
>>> code = komm.VariableToFixedCode.from_sourcewords(2, [(0, 0), (0, 1)])
364369
>>> code.encode([1, 0, 0, 0]) # Code is not fully covering
365370
Traceback (most recent call last):

src/komm/_util/validators.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,15 @@ def validate_transition_matrix(
5757
f"({value.shape[0]}, {value.shape[1]}))"
5858
)
5959
return value
60+
61+
62+
def validate_integer_range(
63+
value: npt.ArrayLike,
64+
*,
65+
low: int = 0,
66+
high: int = 2,
67+
) -> npt.NDArray[np.integer]:
68+
value = np.asarray(value, dtype=int)
69+
if not (np.all(value >= low) and np.all(value < high)):
70+
raise ValueError(f"input contains invalid entries (expected in [{low}:{high}))")
71+
return value

tests/lossless_coding/test_lz77.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,12 @@ def test_lz77_zeros(k):
140140
code = komm.LempelZiv77Code(window_size=32, lookahead_size=8, source_cardinality=2)
141141
source = np.zeros(k, dtype=int)
142142
np.testing.assert_equal(code.decode(code.encode(source)), source)
143+
144+
145+
def test_lz77_invalid_input():
146+
code = komm.LempelZiv77Code(window_size=32, lookahead_size=8, source_cardinality=27)
147+
code.encode([0, 10, 26])
148+
with pytest.raises(ValueError, match="invalid entries"):
149+
code.encode([0, 10, 27])
150+
with pytest.raises(ValueError, match="invalid entries"):
151+
code.encode([-1, 10, 26])

tests/lossless_coding/test_lz78.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,12 @@ def test_lz78_worst_case(k):
149149
assert len(message) == (k - 1) * 2 ** (k + 1) + 2
150150
assert len(compressed) == len_compressed(2 ** (k + 1) - 2, 2)
151151
np.testing.assert_equal(code.decode(compressed), message)
152+
153+
154+
def test_lz78_invalid_input():
155+
code = komm.LempelZiv78Code(source_cardinality=27)
156+
code.encode([0, 10, 26])
157+
with pytest.raises(ValueError, match="invalid entries"):
158+
code.encode([0, 10, 27])
159+
with pytest.raises(ValueError, match="invalid entries"):
160+
code.encode([-1, 10, 26])

tests/lossless_coding/test_lzw.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,12 @@ def test_lzw_encode_decode(source_cardinality, target_cardinality):
6767
# Random message
6868
message = np.random.randint(0, source_cardinality, 1000)
6969
np.testing.assert_equal(code.decode(code.encode(message)), message)
70+
71+
72+
def test_lzw_invalid_input():
73+
code = komm.LempelZivWelchCode(source_cardinality=27)
74+
code.encode([0, 10, 26])
75+
with pytest.raises(ValueError, match="invalid entries"):
76+
code.encode([0, 10, 27])
77+
with pytest.raises(ValueError, match="invalid entries"):
78+
code.encode([-1, 10, 26])

0 commit comments

Comments
 (0)