Skip to content

Commit 26964dc

Browse files
committed
fix Zstd compression and decompression
1 parent dbcee48 commit 26964dc

File tree

3 files changed

+100
-24
lines changed

3 files changed

+100
-24
lines changed

src/main/java/dev/zarr/zarrjava/v3/codec/core/ZstdCodec.java

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import com.fasterxml.jackson.annotation.JsonCreator;
44
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import com.github.luben.zstd.Zstd;
6+
import com.github.luben.zstd.ZstdCompressCtx;
57
import com.github.luben.zstd.ZstdInputStream;
68
import com.github.luben.zstd.ZstdOutputStream;
79
import dev.zarr.zarrjava.ZarrException;
@@ -37,30 +39,40 @@ private void copy(InputStream inputStream, OutputStream outputStream) throws IOE
3739
}
3840

3941
@Override
40-
public ByteBuffer decode(ByteBuffer chunkBytes)
41-
throws ZarrException {
42-
try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); ZstdInputStream inputStream = new ZstdInputStream(
43-
new ByteArrayInputStream(Utils.toArray(chunkBytes)))) {
44-
copy(inputStream, outputStream);
45-
inputStream.close();
46-
return ByteBuffer.wrap(outputStream.toByteArray());
47-
} catch (IOException ex) {
48-
throw new ZarrException("Error in decoding zstd.", ex);
42+
public ByteBuffer decode(ByteBuffer compressedBytes) throws ZarrException {
43+
// Extract the byte array from the ByteBuffer
44+
byte[] compressedArray = new byte[compressedBytes.remaining()];
45+
compressedBytes.get(compressedArray);
46+
47+
// Determine the original size (optional: you might need to store the original size separately)
48+
long originalSize = Zstd.decompressedSize(compressedArray);
49+
if (originalSize == 0) {
50+
throw new ZarrException("Failed to get decompressed size");
4951
}
52+
53+
// Create a buffer for the decompressed data
54+
byte[] decompressed = new byte[(int) originalSize];
55+
56+
// Perform decompression
57+
long bytesDecompressed = Zstd.decompress(decompressed, compressedArray);
58+
if (bytesDecompressed != originalSize) {
59+
throw new ZarrException("Decompression failed, incorrect decompressed size");
60+
}
61+
62+
// Wrap the decompressed byte array into a ByteBuffer
63+
return ByteBuffer.wrap(decompressed);
5064
}
5165

5266
@Override
53-
public ByteBuffer encode(ByteBuffer chunkBytes)
54-
throws ZarrException {
55-
try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); ZstdOutputStream zstdStream = new ZstdOutputStream(
56-
outputStream, configuration.level).setChecksum(
57-
configuration.checksum)) {
58-
zstdStream.write(Utils.toArray(chunkBytes));
59-
zstdStream.close();
60-
return ByteBuffer.wrap(outputStream.toByteArray());
61-
} catch (IOException ex) {
62-
throw new ZarrException("Error in encoding zstd.", ex);
67+
public ByteBuffer encode(ByteBuffer chunkBytes) throws ZarrException {
68+
byte[] arr = chunkBytes.array();
69+
byte[] compressed;
70+
try (ZstdCompressCtx ctx = new ZstdCompressCtx()) {
71+
ctx.setLevel(configuration.level);
72+
ctx.setChecksum(configuration.checksum);
73+
compressed = ctx.compress(arr);
6374
}
75+
return ByteBuffer.wrap(compressed);
6476
}
6577

6678
@Override

src/test/java/dev/zarr/zarrjava/ZarrTest.java

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import com.amazonaws.auth.AnonymousAWSCredentials;
55
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
66
import com.fasterxml.jackson.databind.ObjectMapper;
7+
import com.github.luben.zstd.ZstdCompressCtx;
8+
import com.github.luben.zstd.ZstdInputStream;
79
import com.github.luben.zstd.ZstdOutputStream;
810
import dev.zarr.zarrjava.store.FilesystemStore;
911
import dev.zarr.zarrjava.store.HttpStore;
@@ -19,7 +21,11 @@
1921
import org.junit.jupiter.params.ParameterizedTest;
2022
import org.junit.jupiter.params.provider.CsvSource;
2123
import org.junit.jupiter.params.provider.ValueSource;
24+
import com.github.luben.zstd.Zstd;
25+
import ucar.ma2.MAMath;
2226

27+
import java.io.FileOutputStream;
28+
import java.nio.ByteBuffer;
2329
import java.io.*;
2430
import java.nio.ByteBuffer;
2531
import java.nio.file.Files;
@@ -36,9 +42,11 @@ public class ZarrTest {
3642

3743
final static Path TESTDATA = Paths.get("testdata");
3844
final static Path TESTOUTPUT = Paths.get("testoutput");
39-
final static Path ZARRITA_WRITE_PATH = Paths.get("src/test/java/dev/zarr/zarrjava/zarrita_write.py");
40-
final static Path ZARRITA_READ_PATH = Paths.get("src/test/java/dev/zarr/zarrjava/zarrita_read.py");
41-
final static Path TEST_ZSTD_LIBRARY_PATH = Paths.get("src/test/java/dev/zarr/zarrjava/test_zstd_library.py");
45+
final static Path TEST_PATH = Paths.get("src/test/java/dev/zarr/zarrjava/");
46+
47+
final static Path ZARRITA_WRITE_PATH = TEST_PATH.resolve("zarrita_write.py");
48+
final static Path ZARRITA_READ_PATH = TEST_PATH.resolve("zarrita_read.py");
49+
final static Path TEST_ZSTD_LIBRARY_PATH = TEST_PATH.resolve("test_zstd_library.py");
4250

4351
public static String pythonPath() {
4452
if (System.getProperty("os.name").startsWith("Windows")) {
@@ -93,6 +101,49 @@ public void testReadFromZarrita(String codec) throws IOException, ZarrException,
93101
Assertions.assertArrayEquals(expectedData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT));
94102
}
95103

104+
private void copy(InputStream inputStream, OutputStream outputStream) throws IOException {
105+
byte[] buffer = new byte[4096];
106+
int len;
107+
while ((len = inputStream.read(buffer)) > 0) {
108+
outputStream.write(buffer, 0, len);
109+
}
110+
}
111+
112+
@CsvSource({"0,true", "0,false", "5, true", "5, false"})
113+
@ParameterizedTest
114+
public void testZstdLibrary2(int clevel, boolean checksumFlag) throws IOException, InterruptedException, ZarrException {
115+
//compress using ZstdCompressCtx
116+
int number = 123456;
117+
byte[] src = ByteBuffer.allocate(4).putInt(number).array();
118+
byte[] compressed;
119+
try (ZstdCompressCtx ctx = new ZstdCompressCtx()) {
120+
ctx.setLevel(clevel);
121+
ctx.setChecksum(checksumFlag);
122+
compressed = ctx.compress(src);
123+
}
124+
//decompress with Zstd.decompress
125+
long originalSize = Zstd.decompressedSize(compressed);
126+
byte[] decompressed = Zstd.decompress(compressed, (int) originalSize);
127+
Assertions.assertEquals(number, ByteBuffer.wrap(decompressed).getInt());
128+
129+
//write compressed to file
130+
String compressedDataPath =TESTOUTPUT.resolve("compressed" + clevel + checksumFlag + ".bin").toString();
131+
try (FileOutputStream fos = new FileOutputStream(compressedDataPath)) {
132+
fos.write(compressed);
133+
}
134+
135+
//decompress in python
136+
Process process = new ProcessBuilder(
137+
pythonPath(),
138+
TEST_PATH.resolve("decompress_print.py").toString(),
139+
compressedDataPath,
140+
Integer.toString(number)
141+
).start();
142+
int exitCode = process.waitFor();
143+
assert exitCode == 0;
144+
}
145+
146+
96147
@ParameterizedTest
97148
@CsvSource({"0,true", "0,false", "5, true", "5, false"})
98149
public void testZstdLibrary(int clevel, boolean checksum) throws IOException, InterruptedException {
@@ -295,8 +346,8 @@ public void testTransposeCodec() throws ZarrException {
295346
transposeCodecWrongOrder2.setCoreArrayMetadata(metadata);
296347
transposeCodecWrongOrder3.setCoreArrayMetadata(metadata);
297348

298-
assert ucar.ma2.MAMath.equals(testDataTransposed120, transposeCodec.encode(testData));
299-
assert ucar.ma2.MAMath.equals(testData, transposeCodec.decode(testDataTransposed120));
349+
assert MAMath.equals(testDataTransposed120, transposeCodec.encode(testData));
350+
assert MAMath.equals(testData, transposeCodec.decode(testDataTransposed120));
300351
assertThrows(ZarrException.class, () -> transposeCodecWrongOrder1.encode(testData));
301352
assertThrows(ZarrException.class, () -> transposeCodecWrongOrder2.encode(testData));
302353
assertThrows(ZarrException.class, () -> transposeCodecWrongOrder3.encode(testData));
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Decompress a zstd file and check that it holds the expected big-endian integer.

Usage: decompress_print.py <compressed-file> <expected-integer>

Fails with an AssertionError (non-zero exit status) when the decoded value differs.
"""
import sys

import zstandard as zstd


def _verify(argv):
    """Decode the file named by argv[1] and compare it with the integer in argv[2]."""
    source_file, expected_text = argv[1], argv[2]

    with open(source_file, "rb") as handle:
        payload = handle.read()

    raw = zstd.ZstdDecompressor().decompress(payload)
    assert int.from_bytes(raw, byteorder='big') == int(expected_text)


# Runs unconditionally, as the original top-level statements did.
_verify(sys.argv)

0 commit comments

Comments
 (0)