|
1 | 1 | package org.tensorics.gson.adapters;
|
2 | 2 |
|
3 | 3 | import com.google.common.annotations.VisibleForTesting;
|
4 |
| -import com.google.common.collect.Iterables; |
5 | 4 | import com.google.gson.Gson;
|
6 | 5 | import com.google.gson.TypeAdapter;
|
| 6 | +import com.google.gson.internal.JsonReaderInternalAccess; |
7 | 7 | import com.google.gson.reflect.TypeToken;
|
8 | 8 | import com.google.gson.stream.JsonReader;
|
9 | 9 | import com.google.gson.stream.JsonWriter;
|
10 | 10 | import org.tensorics.core.lang.Tensorics;
|
| 11 | +import org.tensorics.core.tensor.Position; |
11 | 12 | import org.tensorics.core.tensor.Tensor;
|
12 |
| -import org.tensorics.core.tensor.operations.TensorInternals; |
13 | 13 | import org.tensorics.core.tensorbacked.Tensorbacked;
|
14 | 14 | import org.tensorics.core.tensorbacked.TensorbackedInternals;
|
| 15 | +import org.tensorics.gson.util.Nestmaps; |
15 | 16 |
|
16 | 17 | import java.io.IOException;
|
| 18 | +import java.util.HashMap; |
17 | 19 | import java.util.List;
|
18 | 20 | import java.util.Map;
|
19 | 21 |
|
20 | 22 | import static java.util.Objects.requireNonNull;
|
| 23 | +import static org.tensorics.core.tensorbacked.TensorbackedInternals.valueTypeFrom; |
21 | 24 |
|
22 | 25 | public class TensorbackedGsonAdapter<V, TB extends Tensorbacked<V>> extends TypeAdapter<TB> {
|
23 | 26 |
|
24 | 27 | private final Gson context;
|
25 | 28 | private final Class<TB> tensorbackedClass;
|
26 | 29 |
|
| 30 | + private final TypeAdapter<V> valueAdapter; |
| 31 | + |
27 | 32 | public TensorbackedGsonAdapter(Gson context, Class<TB> tensorbackedClass) {
|
28 | 33 | this.context = context;
|
29 | 34 | this.tensorbackedClass = requireNonNull(tensorbackedClass, "tensorbackedClass must not be null.");
|
| 35 | + |
| 36 | + this.valueAdapter = adapterFor(valueTypeFrom(tensorbackedClass)); |
30 | 37 | }
|
31 | 38 |
|
32 | 39 | @Override
|
33 | 40 | public void write(JsonWriter out, TB value) throws IOException {
|
34 | 41 | /*XXX: The context of the tensor will currently NOT be serialized! */
|
35 | 42 |
|
36 | 43 | List<Class<?>> dimensions = TensorbackedInternals.dimensionListFrom(tensorbackedClass);
|
37 |
| - Object nested = nestmap(value.tensor(), dimensions); |
| 44 | + Object nested = Nestmaps.nestmap(value.tensor(), dimensions); |
38 | 45 | if (nested instanceof Map) {
|
39 | 46 | TypeAdapter<Map<?, ?>> adapter = context.getAdapter(new TypeToken<Map<?, ?>>() {
|
40 | 47 | });
|
41 | 48 | adapter.write(out, (Map<?, ?>) nested);
|
42 | 49 | } else { /* This is the special case of a scalar */
|
43 |
| - Class<V> valueType = TensorbackedInternals.valueTypeFrom(tensorbackedClass); |
44 |
| - TypeAdapter<V> adapter = context.getAdapter(TypeToken.get(valueType)); |
45 |
| - adapter.write(out, (V) nested); |
| 50 | + valueAdapter.write(out, (V) nested); |
46 | 51 | }
|
47 | 52 | }
|
48 | 53 |
|
| 54 | + @Override |
| 55 | + public TB read(JsonReader in) throws IOException { |
| 56 | + List<Class<?>> dimensions = TensorbackedInternals.dimensionListFrom(tensorbackedClass); |
| 57 | + Object object = recursiveRead(in, dimensions); |
| 58 | + Tensor<V> unnested = Nestmaps.unnestmap(object, dimensions); |
| 59 | + return TensorbackedInternals.createBackedByTensor(tensorbackedClass, unnested); |
| 60 | + } |
| 61 | + |
49 | 62 | @VisibleForTesting
|
50 |
| - static Object nestmap(Tensor<?> tensor, List<Class<?>> dimensions) { |
51 |
| - int tensorDimensionality = Tensorics.dimensionsOf(tensor).size(); |
52 |
| - if (tensorDimensionality != dimensions.size()) { |
53 |
| - throw new IllegalArgumentException("Tensor dimensionality (" + tensorDimensionality + |
54 |
| - ") and number of provided dimensions (" + dimensions.size() + ": " + dimensions + |
55 |
| - ") do not match!"); |
| 63 | + Object recursiveRead(JsonReader in, List<Class<?>> dimensions) throws IOException { |
| 64 | + if (dimensions.isEmpty()) { |
| 65 | + /* This is the special case of a scalar and the final value */ |
| 66 | + return valueAdapter.read(in); |
| 67 | + } else { |
| 68 | + Class<?> thisDim = dimensions.get(0); |
| 69 | + List<Class<?>> remainingDims = dimensions.subList(1, dimensions.size()); |
| 70 | + return readMap(in, thisDim, remainingDims); |
56 | 71 | }
|
| 72 | + } |
57 | 73 |
|
58 |
| - if (dimensions.isEmpty()) { |
59 |
| - return Tensorics.from(tensor).optional().orElse(null); |
| 74 | + private <T> Map<T, Object> readMap(JsonReader in, Class<T> keyDim, List<Class<?>> remainingDimensions) throws IOException { |
| 75 | + TypeAdapter<T> dimAdapter = adapterFor(keyDim); |
| 76 | + Map<T, Object> map = new HashMap<>(); |
| 77 | + |
| 78 | + in.beginObject(); |
| 79 | + while (in.hasNext()) { |
| 80 | + JsonReaderInternalAccess.INSTANCE.promoteNameToValue(in); |
| 81 | + T key = dimAdapter.read(in); |
| 82 | + Object value = recursiveRead(in, remainingDimensions); |
| 83 | + map.put(key, value); |
60 | 84 | }
|
| 85 | + in.endObject(); |
| 86 | + |
| 87 | + return map; |
| 88 | + } |
61 | 89 |
|
62 |
| - Class<?> dimension = Iterables.getLast(dimensions); |
63 |
| - Tensor<? extends Map<?, ?>> mappedOut = TensorInternals.mapOut(tensor).inDirectionOf(dimension); |
64 | 90 |
|
65 |
| - List<Class<?>> remainingDimensions = dimensions.subList(0, dimensions.size() - 1); |
66 |
| - return nestmap(mappedOut, remainingDimensions); |
| 91 | + private <T> TypeAdapter<T> adapterFor(Class<T> valueType) { |
| 92 | + return context.getAdapter(TypeToken.get(valueType)); |
67 | 93 | }
|
68 | 94 |
|
69 |
| - @Override |
70 |
| - public TB read(JsonReader in) throws IOException { |
| 95 | + |
| 96 | + /* |
| 97 | + JsonToken peek = in.peek(); |
| 98 | + if (peek == JsonToken.NULL) { |
| 99 | + in.nextNull(); |
71 | 100 | return null;
|
72 |
| - } |
| 101 | + } |
| 102 | +
|
| 103 | + Map<K, V> map = constructor.construct(); |
| 104 | +
|
| 105 | + if (peek == JsonToken.BEGIN_ARRAY) { |
| 106 | + in.beginArray(); |
| 107 | + while (in.hasNext()) { |
| 108 | + in.beginArray(); // entry array |
| 109 | + K key = keyTypeAdapter.read(in); |
| 110 | + V value = valueTypeAdapter.read(in); |
| 111 | + V replaced = map.put(key, value); |
| 112 | + if (replaced != null) { |
| 113 | + throw new JsonSyntaxException("duplicate key: " + key); |
| 114 | + } |
| 115 | + in.endArray(); |
| 116 | + } |
| 117 | + in.endArray(); |
| 118 | + } |
| 119 | + */ |
73 | 120 | }
|
0 commit comments