Skip to content

Commit 6af6e3f

Browse files
committed
partialDecode for sharding
1 parent 0f7bc1b commit 6af6e3f

File tree

16 files changed

+296
-117
lines changed

16 files changed

+296
-117
lines changed

src/main/java/com/scalableminds/zarrjava/store/S3Store.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ public ByteBuffer get(String[] keys, long start, long end) {
7373

7474
@Override
7575
public void set(String[] keys, ByteBuffer bytes) {
76-
try (InputStream byteStream = new ByteArrayInputStream(bytes.array())) {
76+
try (InputStream byteStream = new ByteArrayInputStream(Utils.toArray(bytes))) {
7777
s3client.putObject(bucketName, resolveKeys(keys), byteStream, new ObjectMetadata());
7878
} catch (IOException e) {
7979
throw new RuntimeException(e);

src/main/java/com/scalableminds/zarrjava/utils/IndexingUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public static ChunkProjection computeProjection(
8787

8888
if (selOffset[dimIdx] + selShape[dimIdx] > dimLimit) {
8989
// selection ends after current chunk
90-
shape[dimIdx] = (int) (chunkShape[dimIdx] - selOffset[dimIdx]);
90+
shape[dimIdx] = (int) (chunkShape[dimIdx] - (selOffset[dimIdx] % chunkShape[dimIdx]));
9191
} else {
9292
// selection ends within current chunk
9393
shape[dimIdx] = (int) (selOffset[dimIdx] + selShape[dimIdx] - dimOffset

src/main/java/com/scalableminds/zarrjava/utils/Utils.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ public static int[] toIntArray(long[] array) {
5252
.toArray();
5353
}
5454

55+
public static byte[] toArray(ByteBuffer buffer) {
56+
byte[] bytes = new byte[buffer.remaining()];
57+
buffer.get(bytes);
58+
return bytes;
59+
}
60+
5561
public static <T> Stream<T> asStream(Iterator<T> sourceIterator) {
5662
Iterable<T> iterable = () -> sourceIterator;
5763
return StreamSupport.stream(iterable.spliterator(), false);

src/main/java/com/scalableminds/zarrjava/v2/Array.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import com.fasterxml.jackson.databind.ObjectMapper;
44
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
55
import com.scalableminds.zarrjava.store.StoreHandle;
6+
import com.scalableminds.zarrjava.utils.Utils;
67
import java.io.IOException;
78
import java.util.Arrays;
89
import java.util.stream.Collectors;
@@ -19,9 +20,7 @@ public class Array {
1920
ObjectMapper objectMapper = new ObjectMapper();
2021
objectMapper.registerModule(new Jdk8Module());
2122
this.metadata = objectMapper.readValue(
22-
storeHandle.resolve(ZARRAY)
23-
.readNonNull()
24-
.array(),
23+
Utils.toArray(storeHandle.resolve(ZARRAY).readNonNull()),
2524
ArrayMetadata.class
2625
);
2726
}

src/main/java/com/scalableminds/zarrjava/v3/Array.java

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ public static Array open(StoreHandle storeHandle) throws IOException, ZarrExcept
3333
storeHandle,
3434
Node.makeObjectMapper()
3535
.readValue(
36-
storeHandle.resolve(ZARR_JSON)
37-
.readNonNull()
38-
.array(),
36+
Utils.toArray(storeHandle.resolve(ZARR_JSON).readNonNull()),
3937
ArrayMetadata.class
4038
)
4139
);
@@ -96,27 +94,55 @@ public ucar.ma2.Array read(final long[] offset, final int[] shape) throws ZarrEx
9694
shape
9795
);
9896

99-
final ucar.ma2.Array chunkArray = readChunk(chunkCoords);
100-
MultiArrayUtils.copyRegion(chunkArray, chunkProjection.chunkOffset, outputArray,
101-
chunkProjection.outOffset, chunkProjection.shape
102-
);
97+
if (chunkIsInArray(chunkCoords)) {
98+
MultiArrayUtils.copyRegion(metadata.allocateFillValueChunk(),
99+
chunkProjection.chunkOffset, outputArray, chunkProjection.outOffset,
100+
chunkProjection.shape
101+
);
102+
}
103+
104+
final String[] chunkKeys = metadata.chunkKeyEncoding.encodeChunkKey(chunkCoords);
105+
final StoreHandle chunkHandle = storeHandle.resolve(chunkKeys);
106+
107+
if (codecPipeline.supportsPartialDecode()) {
108+
System.out.println("decodePartial");
109+
final ucar.ma2.Array chunkArray = codecPipeline.decodePartial(chunkHandle,
110+
Utils.toLongArray(chunkProjection.chunkOffset), chunkProjection.shape,
111+
metadata.coreArrayMetadata);
112+
MultiArrayUtils.copyRegion(chunkArray, new int[metadata.ndim()], outputArray,
113+
chunkProjection.outOffset, chunkProjection.shape
114+
);
115+
} else {
116+
System.out.println("decode");
117+
MultiArrayUtils.copyRegion(readChunk(chunkCoords), chunkProjection.chunkOffset,
118+
outputArray, chunkProjection.outOffset, chunkProjection.shape
119+
);
120+
}
121+
103122
} catch (ZarrException e) {
104123
throw new RuntimeException(e);
105124
}
106125
});
107126
return outputArray;
108127
}
109128

110-
@Nonnull
111-
public ucar.ma2.Array readChunk(long[] chunkCoords) throws ZarrException {
129+
/**
 * Reports whether the chunk at {@code chunkCoords} lies inside the array bounds.
 *
 * @param chunkCoords chunk grid coordinates, one entry per dimension
 * @return {@code false} if any coordinate is negative or if the chunk's start
 *     ({@code chunkCoords[d] * chunkShape[d]}) is at or beyond {@code metadata.shape[d]};
 *     {@code true} otherwise
 */
boolean chunkIsInArray(long[] chunkCoords) {
112130
final int[] chunkShape = metadata.chunkShape();
113-
114131
for (int dimIdx = 0; dimIdx < metadata.ndim(); dimIdx++) {
115132
// Out of bounds when the coordinate is negative or the chunk starts past the array extent.
if (chunkCoords[dimIdx] < 0
116133
|| chunkCoords[dimIdx] * chunkShape[dimIdx] >= metadata.shape[dimIdx]) {
117-
return metadata.allocateFillValueChunk();
134+
return false;
118135
}
119136
}
137+
// Every dimension passed the bounds check.
return true;
138+
}
139+
140+
@Nonnull
141+
public ucar.ma2.Array readChunk(long[] chunkCoords)
142+
throws ZarrException {
143+
if (chunkIsInArray(chunkCoords)) {
144+
return metadata.allocateFillValueChunk();
145+
}
120146

121147
final String[] chunkKeys = metadata.chunkKeyEncoding.encodeChunkKey(chunkCoords);
122148
final StoreHandle chunkHandle = storeHandle.resolve(chunkKeys);
@@ -126,8 +152,7 @@ public ucar.ma2.Array readChunk(long[] chunkCoords) throws ZarrException {
126152
return metadata.allocateFillValueChunk();
127153
}
128154

129-
ucar.ma2.Array chunkArray = codecPipeline.decode(chunkBytes, metadata.coreArrayMetadata);
130-
return chunkArray;
155+
return codecPipeline.decode(chunkBytes, metadata.coreArrayMetadata);
131156
}
132157

133158
public void write(ucar.ma2.Array array) {

src/main/java/com/scalableminds/zarrjava/v3/Group.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import com.fasterxml.jackson.databind.ObjectMapper;
44
import com.scalableminds.zarrjava.ZarrException;
55
import com.scalableminds.zarrjava.store.StoreHandle;
6+
import com.scalableminds.zarrjava.utils.Utils;
67
import java.io.IOException;
78
import java.nio.ByteBuffer;
89
import java.util.Map;
@@ -25,7 +26,7 @@ public static Group open(@Nonnull StoreHandle storeHandle) throws IOException {
2526
StoreHandle metadataHandle = storeHandle.resolve(ZARR_JSON);
2627
ByteBuffer metadataBytes = metadataHandle.readNonNull();
2728
return new Group(storeHandle, Node.makeObjectMapper()
28-
.readValue(metadataBytes.array(), GroupMetadata.class));
29+
.readValue(Utils.toArray(metadataBytes), GroupMetadata.class));
2930
}
3031

3132
public static Group create(
@@ -58,9 +59,9 @@ public Node get(String key) throws ZarrException {
5859
if (metadataBytes == null) {
5960
return null;
6061
}
61-
byte[] metadataBytearray = metadataBytes.array();
62+
byte[] metadataBytearray = Utils.toArray(metadataBytes);
6263
try {
63-
String nodeType = objectMapper.readTree(metadataBytes.array())
64+
String nodeType = objectMapper.readTree(metadataBytearray)
6465
.get("node_type")
6566
.asText();
6667
switch (nodeType) {
src/main/java/com/scalableminds/zarrjava/v3/codec/ArrayBytesCodec.java

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.scalableminds.zarrjava.v3.codec;
22

33
import com.scalableminds.zarrjava.ZarrException;
4+
import com.scalableminds.zarrjava.store.StoreHandle;
45
import com.scalableminds.zarrjava.v3.ArrayMetadata;
56
import java.nio.ByteBuffer;
67
import ucar.ma2.Array;
@@ -13,20 +14,12 @@ ByteBuffer encode(Array chunkArray, ArrayMetadata.CoreArrayMetadata arrayMetadat
1314
Array decode(ByteBuffer chunkBytes, ArrayMetadata.CoreArrayMetadata arrayMetadata)
1415
throws ZarrException;
1516

16-
interface WithPartialEncode extends ArrayBytesCodec {
17-
18-
ByteBuffer encodePartial(
19-
Array chunkArray, long[] offset, int[] shape,
20-
ArrayMetadata.CoreArrayMetadata arrayMetadata
21-
);
22-
}
23-
2417
/**
 * An {@link ArrayBytesCodec} that can additionally decode a region of a chunk given by an
 * offset and a shape.
 */
interface WithPartialDecode extends ArrayBytesCodec {
2518

26-
Array partialDecode(
27-
ByteBuffer chunkBytes, long[] offset, int[] shape,
19+
/**
 * Decodes the region described by {@code offset} and {@code shape} from the chunk behind
 * {@code handle}.
 *
 * <p>NOTE(review): presumably the returned array has exactly {@code shape} extents, since the
 * caller in {@code v3.Array.read} copies from offset zero of the result; confirm.
 *
 * @param handle store handle of the chunk to read from
 * @param offset start of the requested region
 * @param shape extent of the requested region
 * @param arrayMetadata core metadata of the array the chunk belongs to
 * @throws ZarrException declared by this method; the conditions are implementation-specific
 */
Array decodePartial(
20+
StoreHandle handle, long[] offset, int[] shape,
2821
ArrayMetadata.CoreArrayMetadata arrayMetadata
29-
);
22+
) throws ZarrException;
3023
}
3124
}
3225

src/main/java/com/scalableminds/zarrjava/v3/codec/Codec.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
package com.scalableminds.zarrjava.v3.codec;
22

33
import com.fasterxml.jackson.annotation.JsonTypeInfo;
4+
import com.scalableminds.zarrjava.ZarrException;
5+
import com.scalableminds.zarrjava.v3.ArrayMetadata;
46

57
/**
 * Common interface of all codecs. For JSON (de)serialization the concrete type is identified
 * by name through the {@code "name"} property.
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "name")
68
public interface Codec {
79

10+
/**
 * Computes the encoded size for an input of {@code inputByteLength} bytes.
 *
 * <p>NOTE(review): the unit of the result is presumably bytes, inferred from the parameter
 * name and from {@code CodecPipeline} feeding each codec's result into the next; confirm.
 *
 * @param inputByteLength length of the input in bytes
 * @param arrayMetadata core metadata of the array being encoded
 * @return the encoded size
 * @throws ZarrException declared by this method; for example, the Blosc codec in this commit
 *     always throws it
 */
long computeEncodedSize(long inputByteLength, ArrayMetadata.CoreArrayMetadata arrayMetadata)
11+
throws ZarrException;
812
}
913

src/main/java/com/scalableminds/zarrjava/v3/codec/CodecPipeline.java

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import com.scalableminds.zarrjava.ZarrException;
44
import com.scalableminds.zarrjava.store.StoreHandle;
55
import com.scalableminds.zarrjava.v3.ArrayMetadata;
6+
import com.scalableminds.zarrjava.v3.ArrayMetadata.CoreArrayMetadata;
67
import java.nio.ByteBuffer;
78
import java.util.Arrays;
89
import javax.annotation.Nonnull;
@@ -72,6 +73,28 @@ BytesBytesCodec[] getBytesBytesCodecs() {
7273
.toArray(BytesBytesCodec[]::new);
7374
}
7475

76+
public boolean supportsPartialDecode() {
77+
return codecs.length == 1 && codecs[0] instanceof ArrayBytesCodec.WithPartialDecode;
78+
}
79+
80+
@Nonnull
81+
public Array decodePartial(
82+
@Nonnull StoreHandle storeHandle,
83+
long[] offset, int[] shape,
84+
@Nonnull ArrayMetadata.CoreArrayMetadata arrayMetadata
85+
) throws ZarrException {
86+
if (!supportsPartialDecode()) {
87+
throw new ZarrException(
88+
"Partial decode is not supported for these codecs. " + Arrays.toString(codecs));
89+
}
90+
Array chunkArray = ((ArrayBytesCodec.WithPartialDecode) getArrayBytesCodec()).decodePartial(
91+
storeHandle, offset, shape, arrayMetadata);
92+
if (chunkArray == null) {
93+
throw new ZarrException("chunkArray is null. This is likely a bug in one of the codecs.");
94+
}
95+
return chunkArray;
96+
}
97+
7598
@Nonnull
7699
public Array decode(
77100
@Nonnull ByteBuffer chunkBytes,
@@ -80,9 +103,13 @@ public Array decode(
80103
if (chunkBytes == null) {
81104
throw new ZarrException("chunkBytes is null. Ohh nooo.");
82105
}
83-
for (BytesBytesCodec codec : getBytesBytesCodecs()) { // TODO iterate in reverse
106+
107+
BytesBytesCodec[] bytesBytesCodecs = getBytesBytesCodecs();
108+
for (int i = bytesBytesCodecs.length - 1; i >= 0; --i) {
109+
BytesBytesCodec codec = bytesBytesCodecs[i];
84110
chunkBytes = codec.decode(chunkBytes, arrayMetadata);
85111
}
112+
86113
if (chunkBytes == null) {
87114
throw new ZarrException(
88115
"chunkBytes is null. This is likely a bug in one of the codecs. " + Arrays.toString(
@@ -92,9 +119,13 @@ public Array decode(
92119
if (chunkArray == null) {
93120
throw new ZarrException("chunkArray is null. This is likely a bug in one of the codecs.");
94121
}
95-
for (ArrayArrayCodec codec : getArrayArrayCodecs()) { // TODO iterate in reverse
122+
123+
ArrayArrayCodec[] arrayArrayCodecs = getArrayArrayCodecs();
124+
for (int i = arrayArrayCodecs.length - 1; i >= 0; --i) {
125+
ArrayArrayCodec codec = arrayArrayCodecs[i];
96126
chunkArray = codec.decode(chunkArray, arrayMetadata);
97127
}
128+
98129
if (chunkArray == null) {
99130
throw new ZarrException("chunkArray is null. This is likely a bug in one of the codecs.");
100131
}
@@ -117,6 +148,14 @@ public ByteBuffer encode(
117148
return chunkBytes;
118149
}
119150

151+
public long computeEncodedSize(long inputByteLength, CoreArrayMetadata arrayMetadata)
152+
throws ZarrException {
153+
for (Codec codec : codecs) {
154+
inputByteLength = codec.computeEncodedSize(inputByteLength, arrayMetadata);
155+
}
156+
return inputByteLength;
157+
}
158+
120159
public Array partialDecode(
121160
StoreHandle valueHandle, long[] offset, int[] shape,
122161
ArrayMetadata.CoreArrayMetadata arrayMetadata

src/main/java/com/scalableminds/zarrjava/v3/codec/core/BloscCodec.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import com.fasterxml.jackson.databind.ser.std.StdSerializer;
1414
import com.scalableminds.bloscjava.Blosc;
1515
import com.scalableminds.zarrjava.ZarrException;
16+
import com.scalableminds.zarrjava.utils.Utils;
1617
import com.scalableminds.zarrjava.v3.ArrayMetadata;
1718
import com.scalableminds.zarrjava.v3.codec.BytesBytesCodec;
1819
import java.io.IOException;
@@ -35,7 +36,7 @@ public BloscCodec(
3536
public ByteBuffer decode(ByteBuffer chunkBytes, ArrayMetadata.CoreArrayMetadata arrayMetadata)
3637
throws ZarrException {
3738
try {
38-
return ByteBuffer.wrap(Blosc.decompress(chunkBytes.array()));
39+
return ByteBuffer.wrap(Blosc.decompress(Utils.toArray(chunkBytes)));
3940
} catch (Exception ex) {
4041
throw new ZarrException("Error in decoding blosc.", ex);
4142
}
@@ -46,7 +47,7 @@ public ByteBuffer encode(ByteBuffer chunkBytes, ArrayMetadata.CoreArrayMetadata
4647
throws ZarrException {
4748
try {
4849
return ByteBuffer.wrap(
49-
Blosc.compress(chunkBytes.array(), configuration.typesize, configuration.cname,
50+
Blosc.compress(Utils.toArray(chunkBytes), configuration.typesize, configuration.cname,
5051
configuration.clevel,
5152
configuration.shuffle, configuration.blocksize
5253
));
@@ -55,6 +56,12 @@ public ByteBuffer encode(ByteBuffer chunkBytes, ArrayMetadata.CoreArrayMetadata
5556
}
5657
}
5758

59+
@Override
60+
public long computeEncodedSize(long inputByteLength,
61+
ArrayMetadata.CoreArrayMetadata arrayMetadata) throws ZarrException {
62+
throw new ZarrException("Not implemented for Blosc codec.");
63+
}
64+
5865
public static final class CustomShuffleSerializer extends StdSerializer<Blosc.Shuffle> {
5966

6067
public CustomShuffleSerializer() {

0 commit comments

Comments
 (0)