|
1 | 1 | import itertools |
| 2 | +import subprocess |
2 | 3 |
|
3 | 4 | import numpy as np |
4 | 5 | import pytest |
5 | 6 |
|
6 | | -try: |
7 | | - from numcodecs.zstd import Zstd |
8 | | -except ImportError: # pragma: no cover |
9 | | - pytest.skip("numcodecs.zstd not available", allow_module_level=True) |
10 | | - |
11 | | - |
12 | 7 | from numcodecs.tests.common import ( |
13 | 8 | check_backwards_compatibility, |
14 | 9 | check_config, |
|
17 | 12 | check_err_encode_object_buffer, |
18 | 13 | check_repr, |
19 | 14 | ) |
| 15 | +from numcodecs.zstd import Zstd |
20 | 16 |
|
21 | 17 | codecs = [ |
22 | 18 | Zstd(), |
@@ -90,3 +86,73 @@ def test_native_functions(): |
90 | 86 | assert Zstd.default_level() == 3 |
91 | 87 | assert Zstd.min_level() == -131072 |
92 | 88 | assert Zstd.max_level() == 22 |
| 89 | + |
| 90 | + |
| 91 | +def test_streaming_decompression(): |
| 92 | + # Test input frames with unknown frame content size |
| 93 | + codec = Zstd() |
| 94 | + |
| 95 | + # If the zstd command line interface is available, check the bytes |
| 96 | + cli = zstd_cli_available() |
| 97 | + if cli: |
| 98 | + view_zstd_streaming_bytes() |
| 99 | + |
| 100 | + # Encode bytes directly that were the result of streaming compression |
| 101 | + bytes_val = b'(\xb5/\xfd\x00Xa\x00\x00Hello World!' |
| 102 | + dec = codec.decode(bytes_val) |
| 103 | + dec_expected = b'Hello World!' |
| 104 | + assert dec == dec_expected |
| 105 | + if cli: |
| 106 | + assert bytes_val == generate_zstd_streaming_bytes(dec_expected) |
| 107 | + assert dec_expected == generate_zstd_streaming_bytes(bytes_val, decompress=True) |
| 108 | + |
| 109 | + # Two consecutive frames given as input |
| 110 | + bytes2 = bytes(bytearray(bytes_val * 2)) |
| 111 | + dec2 = codec.decode(bytes2) |
| 112 | + dec2_expected = b'Hello World!Hello World!' |
| 113 | + assert dec2 == dec2_expected |
| 114 | + if cli: |
| 115 | + assert dec2_expected == generate_zstd_streaming_bytes(bytes2, decompress=True) |
| 116 | + |
| 117 | + # Single long frame that decompresses to a large output |
| 118 | + bytes3 = b'(\xb5/\xfd\x00X$\x02\x00\xa4\x03ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz\x01\x00:\xfc\xdfs\x05\x05L\x00\x00\x08s\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08k\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08c\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08[\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08S\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08K\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08C\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08u\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08m\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08e\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08]\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08U\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08M\x01\x00\xfc\xff9\x10\x02M\x00\x00\x08E\x01\x00\xfc\x7f\x1d\x08\x01' |
| 119 | + dec3 = codec.decode(bytes3) |
| 120 | + dec3_expected = b'ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz' * 1024 * 32 |
| 121 | + assert dec3 == dec3_expected |
| 122 | + if cli: |
| 123 | + assert bytes3 == generate_zstd_streaming_bytes(dec3_expected) |
| 124 | + assert dec3_expected == generate_zstd_streaming_bytes(bytes3, decompress=True) |
| 125 | + |
| 126 | + # Garbage input results in an error |
| 127 | + bytes4 = bytes(bytearray([0, 0, 0, 0, 0, 0, 0, 0])) |
| 128 | + with pytest.raises(RuntimeError, match='Zstd decompression error: invalid input data'): |
| 129 | + codec.decode(bytes4) |
| 130 | + |
| 131 | + |
| 132 | +def generate_zstd_streaming_bytes(input: bytes, *, decompress: bool = False) -> bytes: |
| 133 | + """ |
| 134 | + Use the zstd command line interface to compress or decompress bytes in streaming mode. |
| 135 | + """ |
| 136 | + if decompress: |
| 137 | + args = ["-d"] |
| 138 | + else: |
| 139 | + args = [] |
| 140 | + |
| 141 | + p = subprocess.run(["zstd", "--no-check", *args], input=input, capture_output=True) |
| 142 | + return p.stdout |
| 143 | + |
| 144 | + |
| 145 | +def view_zstd_streaming_bytes(): |
| 146 | + bytes_val = generate_zstd_streaming_bytes(b"Hello world!") |
| 147 | + print(f" bytes_val = {bytes_val}") |
| 148 | + |
| 149 | + bytes3 = generate_zstd_streaming_bytes( |
| 150 | + b"ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz" * 1024 * 32 |
| 151 | + ) |
| 152 | + print(f" bytes3 = {bytes3}") |
| 153 | + |
| 154 | + |
| 155 | +def zstd_cli_available() -> bool: |
| 156 | + return not subprocess.run( |
| 157 | + ["zstd", "-V"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL |
| 158 | + ).returncode |
0 commit comments