Skip to content

Commit c2289a2

Browse files
add TBoolTest (mix of TStringTest and NumericTypesTestBase)
1 parent 89703f7 commit c2289a2

File tree

1 file changed

+142
-0
lines changed
  • tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types

1 file changed

+142
-0
lines changed
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
package org.tensorflow.types;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertNotNull;
5+
6+
import org.junit.jupiter.api.Test;
7+
import org.tensorflow.EagerSession;
8+
import org.tensorflow.ndarray.NdArray;
9+
import org.tensorflow.ndarray.NdArrays;
10+
import org.tensorflow.ndarray.Shape;
11+
import org.tensorflow.ndarray.index.Indices;
12+
import org.tensorflow.op.Ops;
13+
import org.tensorflow.op.core.Constant;
14+
import org.tensorflow.op.math.LogicalAnd;
15+
import org.tensorflow.op.math.LogicalNot;
16+
import org.tensorflow.op.math.LogicalOr;
17+
18+
class TBoolTest {
19+
20+
@Test
21+
void createScalar() {
22+
TBool tensorT = TBool.scalarOf(true);
23+
assertNotNull(tensorT);
24+
assertEquals(Shape.scalar(), tensorT.shape());
25+
assertEquals(true, tensorT.getObject());
26+
27+
TBool tensorF = TBool.scalarOf(false);
28+
assertNotNull(tensorF);
29+
assertEquals(Shape.scalar(), tensorF.shape());
30+
assertEquals(false, tensorF.getObject());
31+
}
32+
33+
@Test
34+
void createVector() {
35+
TBool tensor = TBool.vectorOf(true, false);
36+
assertNotNull(tensor);
37+
assertEquals(Shape.of(2), tensor.shape());
38+
assertEquals(true, tensor.getObject(0));
39+
assertEquals(false, tensor.getObject(1));
40+
}
41+
42+
@Test
43+
void createCopy() {
44+
NdArray<Boolean> bools =
45+
NdArrays.ofObjects(Boolean.class, Shape.of(2, 2))
46+
.setObject(true, 0, 0)
47+
.setObject(false, 0, 1)
48+
.setObject(false, 1, 0)
49+
.setObject(true, 1, 1);
50+
51+
TBool tensor = TBool.tensorOf(bools);
52+
assertNotNull(tensor);
53+
bools
54+
.scalars()
55+
.forEachIndexed((idx, s) -> assertEquals(s.getObject(), tensor.getObject(idx)));
56+
}
57+
58+
@Test
59+
void initializeTensorsWithBools() {
60+
// Allocate a tensor of booleans of the shape (2, 3, 2)
61+
TBool tensor = TBool.tensorOf(Shape.of(2, 3, 2));
62+
63+
assertEquals(3, tensor.rank());
64+
assertEquals(12, tensor.size());
65+
NdArray<Boolean> data = (NdArray<Boolean>) tensor;
66+
67+
try (EagerSession session = EagerSession.create()) {
68+
Ops tf = Ops.create(session);
69+
70+
// Initialize tensor memory with falses and take a snapshot
71+
data.scalars().forEach(scalar -> ((NdArray<Boolean>) scalar).setObject(false));
72+
Constant<TBool> x = tf.constantOf(tensor);
73+
74+
// Initialize the same tensor memory with trues and take a snapshot
75+
data.scalars().forEach(scalar -> ((NdArray<Boolean>) scalar).setObject(true));
76+
Constant<TBool> y = tf.constantOf(tensor);
77+
78+
// Calculate x AND y and validate the result
79+
LogicalAnd xAndY = tf.math.logicalAnd(x, y);
80+
((NdArray<Boolean>) xAndY.asTensor())
81+
.scalars()
82+
.forEach(scalar -> assertEquals(false, scalar.getObject()));
83+
84+
// Calculate x AND y and validate the result
85+
LogicalOr xOrY = tf.math.logicalOr(x, y);
86+
((NdArray<Boolean>) xOrY.asTensor())
87+
.scalars()
88+
.forEach(scalar -> assertEquals(true, scalar.getObject()));
89+
90+
// Calculate !x and validate the result against y
91+
LogicalNot notX = tf.math.logicalNot(x);
92+
assertEquals(y.asTensor(), notX.asTensor());
93+
}
94+
}
95+
96+
@Test
97+
void setAndCompute() {
98+
NdArray<Boolean> heapData =
99+
NdArrays.ofBooleans(Shape.of(4))
100+
.setObject(true, 0)
101+
.setObject(false, 1)
102+
.setObject(true, 2)
103+
.setObject(false, 3);
104+
105+
// Creates a 2x2 matrix
106+
try (TBool tensor = TBool.tensorOf(Shape.of(2, 2))) {
107+
NdArray<Boolean> data = (NdArray<Boolean>) tensor;
108+
109+
// Copy first 2 values of the vector to the first row of the matrix
110+
data.set(heapData.slice(Indices.range(0, 2)), 0);
111+
112+
// Copy values at an odd position in the vector as the second row of the matrix
113+
data.set(heapData.slice(Indices.odd()), 1);
114+
115+
assertEquals(true, data.getObject(0, 0));
116+
assertEquals(false, data.getObject(0, 1));
117+
assertEquals(false, data.getObject(1, 0));
118+
assertEquals(false, data.getObject(1, 1));
119+
120+
// Read rows of the tensor in reverse order
121+
NdArray<Boolean> flippedData = data.slice(Indices.flip(), Indices.flip());
122+
123+
assertEquals(false, flippedData.getObject(0, 0));
124+
assertEquals(false, flippedData.getObject(0, 1));
125+
assertEquals(false, flippedData.getObject(1, 0));
126+
assertEquals(true, flippedData.getObject(1, 1));
127+
128+
try (EagerSession session = EagerSession.create()) {
129+
Ops tf = Ops.create(session);
130+
131+
LogicalNot sub = tf.math.logicalNot(tf.constantOf(tensor));
132+
NdArray<Boolean> result = (NdArray<Boolean>) sub.asTensor();
133+
134+
assertEquals(false, result.getObject(0, 0));
135+
assertEquals(true, result.getObject(0, 1));
136+
assertEquals(true, result.getObject(1, 0));
137+
assertEquals(true, result.getObject(1, 1));
138+
}
139+
}
140+
}
141+
142+
}

0 commit comments

Comments
 (0)