Skip to content

Commit 1375642

Browse files
committed
simplest version of tensorbackeds serialization works
1 parent fc64650 commit 1375642

File tree

6 files changed

+265
-0
lines changed

6 files changed

+265
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,6 @@
2121

2222
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
2323
hs_err_pid*
24+
/.gradle/
25+
/.idea/
26+
/build/
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package org.tensorics.gson.adapters;
2+
3+
import com.google.common.annotations.VisibleForTesting;
4+
import com.google.common.collect.ImmutableList;
5+
import com.google.common.collect.Iterables;
6+
import com.google.common.collect.Lists;
7+
import com.google.gson.Gson;
8+
import com.google.gson.TypeAdapter;
9+
import com.google.gson.reflect.TypeToken;
10+
import com.google.gson.stream.JsonReader;
11+
import com.google.gson.stream.JsonWriter;
12+
import org.tensorics.core.lang.Tensorics;
13+
import org.tensorics.core.tensor.Tensor;
14+
import org.tensorics.core.tensor.operations.TensorInternals;
15+
import org.tensorics.core.tensorbacked.Tensorbacked;
16+
import org.tensorics.core.tensorbacked.TensorbackedInternals;
17+
import org.tensorics.core.tensorbacked.Tensorbackeds;
18+
19+
import java.io.IOException;
20+
import java.util.List;
21+
import java.util.Map;
22+
23+
import static java.util.Objects.requireNonNull;
24+
25+
public class TensorbackedGsonAdapter<V, TB extends Tensorbacked<V>> extends TypeAdapter<TB> {
26+
27+
private final Gson context;
28+
private final Class<TB> tensorbackedClass;
29+
30+
public TensorbackedGsonAdapter(Gson context, Class<TB> tensorbackedClass) {
31+
this.context = context;
32+
this.tensorbackedClass = requireNonNull(tensorbackedClass, "tensorbackedClass must not be null.");
33+
}
34+
35+
@Override
36+
public void write(JsonWriter out, TB value) throws IOException {
37+
/*XXX: The context of the tensor will currently NOT be serialized! */
38+
39+
List<Class<?>> dimensions = TensorbackedInternals.dimensionListFrom(tensorbackedClass);
40+
Object nested = nested(value.tensor(), dimensions);
41+
if (nested instanceof Map) {
42+
TypeAdapter<Map<?, ?>> adapter = context.getAdapter(new TypeToken<Map<?, ?>>() {
43+
});
44+
adapter.write(out, (Map<?, ?>) nested);
45+
} else { /* This is the special case of a scalar */
46+
Class<V> valueType = TensorbackedInternals.valueTypeFrom(tensorbackedClass);
47+
TypeAdapter<V> adapter = context.getAdapter(TypeToken.get(valueType));
48+
adapter.write(out, (V) nested);
49+
}
50+
}
51+
52+
@VisibleForTesting
53+
static Object nested(Tensor<?> tensor, List<Class<?>> dimensions) {
54+
if (Tensorics.dimensionsOf(tensor).size() != dimensions.size()) {
55+
throw new IllegalArgumentException("Tensor dimension and provided dimension do not match!");
56+
}
57+
58+
if (dimensions.isEmpty()) {
59+
return Tensorics.from(tensor).optional().orElse(null);
60+
}
61+
62+
Class<?> dimension = Iterables.getLast(dimensions);
63+
Tensor<? extends Map<?, ?>> mappedOut = TensorInternals.mapOut(tensor).inDirectionOf(dimension);
64+
65+
List<Class<?>> remainingDimensions = dimensions.subList(0, dimensions.size() - 1);
66+
return nested(mappedOut, remainingDimensions);
67+
}
68+
69+
@Override
70+
public TB read(JsonReader in) throws IOException {
71+
return null;
72+
}
73+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package org.tensorics.gson.adapters;
2+
3+
import com.google.gson.Gson;
4+
import com.google.gson.TypeAdapter;
5+
import com.google.gson.TypeAdapterFactory;
6+
import com.google.gson.reflect.TypeToken;
7+
import org.tensorics.core.tensorbacked.Tensorbacked;
8+
9+
public class TensorbackedGsonAdapterFactory implements TypeAdapterFactory {
10+
11+
12+
@Override
13+
public <T> TypeAdapter<T> create(Gson gson, TypeToken<T> type) {
14+
Class<? super T> rawType = type.getRawType();
15+
if (!Tensorbacked.class.isAssignableFrom(rawType)) {
16+
return null;
17+
}
18+
Class<? extends Tensorbacked<?>> tensorbackedClass = (Class<? extends Tensorbacked<?>>) rawType;
19+
return new TensorbackedGsonAdapter(gson, tensorbackedClass);
20+
}
21+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package org.tensorics.gson.adapters;
2+
3+
import com.google.common.collect.ImmutableList;
4+
import com.google.common.collect.ImmutableMap;
5+
import org.assertj.core.api.Assertions;
6+
import org.junit.Test;
7+
import org.tensorics.core.lang.Tensorics;
8+
import org.tensorics.core.tensor.ImmutableScalar;
9+
import org.tensorics.core.tensor.Scalar;
10+
import org.tensorics.core.tensor.Tensor;
11+
import org.tensorics.core.tensorbacked.AbstractTensorbacked;
12+
import org.tensorics.core.tensorbacked.TensorbackedInternals;
13+
import org.tensorics.core.tensorbacked.annotation.Dimensions;
14+
15+
import static org.tensorics.core.lang.Tensorics.at;
16+
17+
public class NestingTest {
18+
19+
@Test
20+
public void mapoutTensorWorks() {
21+
AnInheritedTensorbacked val = Tensorics.builderFor(AnInheritedTensorbacked.class)//
22+
.put(at("A", 1), 0.11)//
23+
.put(at("B", 1), 0.21)
24+
.put(at("A", 2), 0.12)//
25+
.put(at("B", 2), 0.22)
26+
.build();
27+
28+
Object nested = TensorbackedGsonAdapter.nested(val.tensor(), TensorbackedInternals.dimensionListFrom(AnInheritedTensorbacked.class));
29+
30+
Assertions.assertThat(nested).isEqualTo(ImmutableMap.of(
31+
"A", ImmutableMap.of(1, 0.11, 2, 0.12), //
32+
"B", ImmutableMap.of(1, 0.21, 2, 0.22)
33+
));
34+
}
35+
36+
37+
@Test
38+
public void mapoutScalarIsPlain() {
39+
Scalar<Double> scalar = Tensorics.scalarOf(0.33);
40+
41+
Object nested = TensorbackedGsonAdapter.nested(scalar, ImmutableList.of());
42+
Assertions.assertThat(nested).isEqualTo(0.33);
43+
}
44+
45+
@Dimensions({String.class, Integer.class})
46+
public static class AnInheritedTensorbacked extends AbstractTensorbacked<Double> {
47+
48+
public AnInheritedTensorbacked(Tensor<Double> tensor) {
49+
super(tensor);
50+
}
51+
52+
}
53+
54+
@Test
55+
public void mapoutEmptyTensorIsNull() {
56+
Tensor<Object> empty = Tensorics.builder(String.class, Integer.class).build();
57+
Object nested = TensorbackedGsonAdapter.nested(empty, ImmutableList.of(String.class, Integer.class));
58+
Assertions.assertThat(nested).isNull();
59+
}
60+
61+
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package org.tensorics.gson.adapters;
2+
3+
import com.google.gson.Gson;
4+
import com.google.gson.TypeAdapter;
5+
import com.google.gson.TypeAdapterFactory;
6+
import com.google.gson.reflect.TypeToken;
7+
import org.assertj.core.api.Assertions;
8+
import org.junit.Test;
9+
import org.tensorics.core.lang.Tensorics;
10+
import org.tensorics.core.tensor.Tensor;
11+
import org.tensorics.core.tensorbacked.AbstractTensorbacked;
12+
import org.tensorics.core.tensorbacked.annotation.Dimensions;
13+
14+
import static org.tensorics.core.lang.Tensorics.at;
15+
16+
public class TensorbackedGsonAdapterFactoryTest {
17+
18+
private final TypeAdapterFactory factory = new TensorbackedGsonAdapterFactory();
19+
private final Gson gson = new Gson();
20+
21+
@Test
22+
public void noAdapterForNotATensorbacked() {
23+
TypeAdapter<InvalidTensorbacked> adapter = factory.create(gson, TypeToken.get(InvalidTensorbacked.class));
24+
Assertions.assertThat(adapter).isNotNull();
25+
/* The adapter returned here is not null, however, it will give problems when determining the dimensions...*/
26+
}
27+
28+
@Test
29+
public void adapterReturnedForTensorbacked() {
30+
TypeAdapter<AnInheritedTensorbacked> adapter = factory.create(gson, new TypeToken<AnInheritedTensorbacked>() {
31+
});
32+
Assertions.assertThat(adapter).isNotNull();
33+
}
34+
35+
@Dimensions({String.class, Integer.class})
36+
public static class AnInheritedTensorbacked extends AbstractTensorbacked<Double> {
37+
38+
public AnInheritedTensorbacked(Tensor<Double> tensor) {
39+
super(tensor);
40+
}
41+
42+
}
43+
44+
/**
45+
* Despite not inheriting from abstract tensorbacked, this is not valid tensorbacked object, as it does not have
46+
* the required annotation.
47+
*/
48+
public static class InvalidTensorbacked extends AbstractTensorbacked<Double> {
49+
50+
public InvalidTensorbacked(Tensor<Double> tensor) {
51+
super(tensor);
52+
}
53+
54+
}
55+
56+
57+
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package org.tensorics.gson.adapters;
2+
3+
import com.google.gson.Gson;
4+
import com.google.gson.GsonBuilder;
5+
import org.junit.Test;
6+
import org.tensorics.core.lang.Tensorics;
7+
import org.tensorics.core.tensorbacked.dimtyped.TensorbackedScalar;
8+
9+
import static org.assertj.core.api.Assertions.assertThat;
10+
import static org.tensorics.core.lang.Tensorics.at;
11+
12+
public class TensorbackedGsonAdapterTest {
13+
14+
private final Gson gson = new GsonBuilder()//
15+
.registerTypeAdapterFactory(new TensorbackedGsonAdapterFactory())//
16+
.create();
17+
18+
@Test
19+
public void simpleTensorSerializationIsOk() {
20+
TensorbackedGsonAdapterFactoryTest.AnInheritedTensorbacked val = Tensorics.builderFor(TensorbackedGsonAdapterFactoryTest.AnInheritedTensorbacked.class)//
21+
.put(at("A", 1), 0.11)//
22+
.put(at("B", 1), 0.21)
23+
.put(at("A", 2), 0.12)//
24+
.put(at("B", 2), 0.22)
25+
.build();
26+
27+
String string = gson.toJson(val);
28+
System.out.println(string);
29+
assertThat(string).isNotNull();
30+
assertThat(string).isNotEmpty();
31+
assertThat(string).isEqualTo("{\"A\":{\"1\":0.11,\"2\":0.12},\"B\":{\"1\":0.21,\"2\":0.22}}");
32+
}
33+
34+
@Test
35+
public void simpleScalarSerializationIsOk() {
36+
AScalarBacked val = Tensorics.builderForScalar(AScalarBacked.class).put(0.33).build();
37+
38+
String string = gson.toJson(val);
39+
System.out.println(string);
40+
assertThat(string).isNotNull();
41+
assertThat(string).isNotEmpty();
42+
assertThat(string).isEqualTo("0.33");
43+
}
44+
45+
public interface AScalarBacked extends TensorbackedScalar<Double> {
46+
47+
}
48+
49+
50+
}

0 commit comments

Comments
 (0)