Skip to content

Commit dbcee48

Browse files
committed
add testZstdLibrary
1 parent e51ac9d commit dbcee48

File tree

3 files changed

+96
-8
lines changed

3 files changed

+96
-8
lines changed

src/test/java/dev/zarr/zarrjava/ZarrTest.java

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,24 @@
44
import com.amazonaws.auth.AnonymousAWSCredentials;
55
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
66
import com.fasterxml.jackson.databind.ObjectMapper;
7+
import com.github.luben.zstd.ZstdOutputStream;
78
import dev.zarr.zarrjava.store.FilesystemStore;
89
import dev.zarr.zarrjava.store.HttpStore;
910
import dev.zarr.zarrjava.store.S3Store;
1011
import dev.zarr.zarrjava.store.StoreHandle;
1112
import dev.zarr.zarrjava.utils.MultiArrayUtils;
13+
import dev.zarr.zarrjava.utils.Utils;
1214
import dev.zarr.zarrjava.v3.*;
1315
import dev.zarr.zarrjava.v3.codec.core.TransposeCodec;
1416
import org.junit.jupiter.api.Assertions;
1517
import org.junit.jupiter.api.BeforeAll;
1618
import org.junit.jupiter.api.Test;
1719
import org.junit.jupiter.params.ParameterizedTest;
20+
import org.junit.jupiter.params.provider.CsvSource;
1821
import org.junit.jupiter.params.provider.ValueSource;
1922

20-
import java.io.BufferedReader;
21-
import java.io.File;
22-
import java.io.IOException;
23-
import java.io.InputStreamReader;
23+
import java.io.*;
24+
import java.nio.ByteBuffer;
2425
import java.nio.file.Files;
2526
import java.nio.file.Path;
2627
import java.nio.file.Paths;
@@ -37,6 +38,7 @@ public class ZarrTest {
3738
final static Path TESTOUTPUT = Paths.get("testoutput");
3839
final static Path ZARRITA_WRITE_PATH = Paths.get("src/test/java/dev/zarr/zarrjava/zarrita_write.py");
3940
final static Path ZARRITA_READ_PATH = Paths.get("src/test/java/dev/zarr/zarrjava/zarrita_read.py");
41+
final static Path TEST_ZSTD_LIBRARY_PATH = Paths.get("src/test/java/dev/zarr/zarrjava/test_zstd_library.py");
4042

4143
public static String pythonPath() {
4244
if (System.getProperty("os.name").startsWith("Windows")) {
@@ -91,10 +93,43 @@ public void testReadFromZarrita(String codec) throws IOException, ZarrException,
9193
Assertions.assertArrayEquals(expectedData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT));
9294
}
9395

96+
@ParameterizedTest
97+
@CsvSource({"0,true", "0,false", "5, true", "5, false"})
98+
public void testZstdLibrary(int clevel, boolean checksum) throws IOException, InterruptedException {
99+
String zstd_file = TESTOUTPUT + "/testZstdLibrary" + clevel + checksum + ".zstd";
100+
101+
ByteBuffer testBytes = ByteBuffer.allocate(1024);
102+
testBytes.putInt(42);
103+
104+
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
105+
ZstdOutputStream zstdStream = new ZstdOutputStream(outputStream, clevel);
106+
zstdStream.setChecksum(checksum);
107+
zstdStream.write(Utils.toArray(testBytes));
108+
zstdStream.close();
109+
ByteBuffer encodedBytes = ByteBuffer.wrap(outputStream.toByteArray());
110+
try (FileOutputStream fileOutputStream = new FileOutputStream(zstd_file)) {
111+
fileOutputStream.write(encodedBytes.array());
112+
}
113+
String command = pythonPath();
114+
ProcessBuilder pb = new ProcessBuilder(command, TEST_ZSTD_LIBRARY_PATH.toString(), zstd_file);
115+
Process process = pb.start();
116+
117+
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
118+
String line;
119+
while ((line = reader.readLine()) != null) {
120+
System.out.println(line);
121+
}
122+
BufferedReader readerErr = new BufferedReader(new InputStreamReader(process.getErrorStream()));
123+
while ((line = readerErr.readLine()) != null) {
124+
System.err.println(line);
125+
}
126+
int exitCode = process.waitFor();
127+
assert exitCode == 0;
128+
}
129+
94130
//TODO: add crc32c
95-
//Disabled "zstd": known issue
96131
@ParameterizedTest
97-
@ValueSource(strings = {"blosc", "gzip", "bytes", "transpose", "sharding_start", "sharding_end"})
132+
@ValueSource(strings = {"blosc", "gzip", "zstd", "bytes", "transpose", "sharding_start", "sharding_end"})
98133
public void testWriteToZarrita(String codec) throws IOException, ZarrException, InterruptedException {
99134
StoreHandle storeHandle = new FilesystemStore(TESTOUTPUT).resolve("write_to_zarrita", codec);
100135
ArrayMetadataBuilder builder = Array.metadataBuilder()
@@ -216,8 +251,31 @@ public void testCodecsWriteRead(String codec) throws IOException, ZarrException,
216251
Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT));
217252
}
218253

254+
@ParameterizedTest
255+
@CsvSource({"0,true", "0,false", "5, true", "5, false"})
256+
public void testZstdCodecReadWrite(int clevel, boolean checksum) throws ZarrException, IOException {
257+
int[] testData = new int[16 * 16 * 16];
258+
Arrays.setAll(testData, p -> p);
259+
260+
StoreHandle storeHandle = new FilesystemStore(TESTOUTPUT).resolve("testZstdCodecReadWrite", "checksum_" + checksum, "clevel_" + clevel);
261+
ArrayMetadataBuilder builder = Array.metadataBuilder()
262+
.withShape(16, 16, 16)
263+
.withDataType(DataType.UINT32)
264+
.withChunkShape(2, 4, 8)
265+
.withFillValue(0)
266+
.withCodecs(c -> c.withZstd(clevel, checksum));
267+
Array writeArray = Array.create(storeHandle, builder.build());
268+
writeArray.write(ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{16, 16, 16}, testData));
269+
270+
Array readArray = Array.open(storeHandle);
271+
ucar.ma2.Array result = readArray.read();
272+
273+
Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT));
274+
275+
}
276+
219277
@Test
220-
public void testCodecTranspose() throws IOException, ZarrException, InterruptedException {
278+
public void testTransposeCodec() throws ZarrException {
221279
ucar.ma2.Array testData = ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{2, 3, 3}, new int[]{
222280
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
223281
ucar.ma2.Array testDataTransposed120 = ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{3, 3, 2}, new int[]{
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import struct
2+
import sys
3+
4+
import zstandard as zstd
5+
6+
zstd_file = sys.argv[1]
7+
8+
9+
def compress_data_to_file(file_path, integer_value):
10+
data = struct.pack('>i', integer_value)
11+
compressor = zstd.ZstdCompressor(level=0)
12+
compressed_data = compressor.compress(data)
13+
with open(file_path, 'wb') as file:
14+
file.write(compressed_data)
15+
16+
17+
def decompress_zstd_file(file_path):
18+
with open(file_path, 'rb') as file:
19+
compressed_data = file.read()
20+
decompressor = zstd.ZstdDecompressor() # is with FORMAT_ZSTD1
21+
22+
return decompressor.decompress(compressed_data)
23+
24+
25+
# for comparison
26+
compress_data_to_file(zstd_file + "_", 42)
27+
28+
decompressed_data = decompress_zstd_file(zstd_file)
29+
int_value = int.from_bytes(decompressed_data[:4], byteorder='big')
30+
assert int_value == 42

src/test/java/dev/zarr/zarrjava/zarrita_read.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
elif codec_string == "gzip":
1111
codec = [zarrita.codecs.bytes_codec(), zarrita.codecs.gzip_codec()]
1212
elif codec_string == "zstd":
13-
codec = [zarrita.codecs.bytes_codec(), zarrita.codecs.zstd_codec()]
13+
codec = [zarrita.codecs.bytes_codec(), zarrita.codecs.zstd_codec(checksum=True)]
1414
elif codec_string == "bytes":
1515
codec = [zarrita.codecs.bytes_codec()]
1616
elif codec_string == "transpose":

0 commit comments

Comments
 (0)