1+ package org .tensorflow .types ;
2+
3+ import static org .junit .jupiter .api .Assertions .assertEquals ;
4+ import static org .junit .jupiter .api .Assertions .assertNotNull ;
5+
6+ import org .junit .jupiter .api .Test ;
7+ import org .tensorflow .EagerSession ;
8+ import org .tensorflow .ndarray .NdArray ;
9+ import org .tensorflow .ndarray .NdArrays ;
10+ import org .tensorflow .ndarray .Shape ;
11+ import org .tensorflow .ndarray .index .Indices ;
12+ import org .tensorflow .op .Ops ;
13+ import org .tensorflow .op .core .Constant ;
14+ import org .tensorflow .op .math .LogicalAnd ;
15+ import org .tensorflow .op .math .LogicalNot ;
16+ import org .tensorflow .op .math .LogicalOr ;
17+
18+ class TBoolTest {
19+
20+ @ Test
21+ void createScalar () {
22+ TBool tensorT = TBool .scalarOf (true );
23+ assertNotNull (tensorT );
24+ assertEquals (Shape .scalar (), tensorT .shape ());
25+ assertEquals (true , tensorT .getObject ());
26+
27+ TBool tensorF = TBool .scalarOf (false );
28+ assertNotNull (tensorF );
29+ assertEquals (Shape .scalar (), tensorF .shape ());
30+ assertEquals (false , tensorF .getObject ());
31+ }
32+
33+ @ Test
34+ void createVector () {
35+ TBool tensor = TBool .vectorOf (true , false );
36+ assertNotNull (tensor );
37+ assertEquals (Shape .of (2 ), tensor .shape ());
38+ assertEquals (true , tensor .getObject (0 ));
39+ assertEquals (false , tensor .getObject (1 ));
40+ }
41+
42+ @ Test
43+ void createCopy () {
44+ NdArray <Boolean > bools =
45+ NdArrays .ofObjects (Boolean .class , Shape .of (2 , 2 ))
46+ .setObject (true , 0 , 0 )
47+ .setObject (false , 0 , 1 )
48+ .setObject (false , 1 , 0 )
49+ .setObject (true , 1 , 1 );
50+
51+ TBool tensor = TBool .tensorOf (bools );
52+ assertNotNull (tensor );
53+ bools
54+ .scalars ()
55+ .forEachIndexed ((idx , s ) -> assertEquals (s .getObject (), tensor .getObject (idx )));
56+ }
57+
58+ @ Test
59+ void initializeTensorsWithBools () {
60+ // Allocate a tensor of booleans of the shape (2, 3, 2)
61+ TBool tensor = TBool .tensorOf (Shape .of (2 , 3 , 2 ));
62+
63+ assertEquals (3 , tensor .rank ());
64+ assertEquals (12 , tensor .size ());
65+ NdArray <Boolean > data = (NdArray <Boolean >) tensor ;
66+
67+ try (EagerSession session = EagerSession .create ()) {
68+ Ops tf = Ops .create (session );
69+
70+ // Initialize tensor memory with falses and take a snapshot
71+ data .scalars ().forEach (scalar -> ((NdArray <Boolean >) scalar ).setObject (false ));
72+ Constant <TBool > x = tf .constantOf (tensor );
73+
74+ // Initialize the same tensor memory with trues and take a snapshot
75+ data .scalars ().forEach (scalar -> ((NdArray <Boolean >) scalar ).setObject (true ));
76+ Constant <TBool > y = tf .constantOf (tensor );
77+
78+ // Calculate x AND y and validate the result
79+ LogicalAnd xAndY = tf .math .logicalAnd (x , y );
80+ ((NdArray <Boolean >) xAndY .asTensor ())
81+ .scalars ()
82+ .forEach (scalar -> assertEquals (false , scalar .getObject ()));
83+
84+ // Calculate x AND y and validate the result
85+ LogicalOr xOrY = tf .math .logicalOr (x , y );
86+ ((NdArray <Boolean >) xOrY .asTensor ())
87+ .scalars ()
88+ .forEach (scalar -> assertEquals (true , scalar .getObject ()));
89+
90+ // Calculate !x and validate the result against y
91+ LogicalNot notX = tf .math .logicalNot (x );
92+ assertEquals (y .asTensor (), notX .asTensor ());
93+ }
94+ }
95+
96+ @ Test
97+ void setAndCompute () {
98+ NdArray <Boolean > heapData =
99+ NdArrays .ofBooleans (Shape .of (4 ))
100+ .setObject (true , 0 )
101+ .setObject (false , 1 )
102+ .setObject (true , 2 )
103+ .setObject (false , 3 );
104+
105+ // Creates a 2x2 matrix
106+ try (TBool tensor = TBool .tensorOf (Shape .of (2 , 2 ))) {
107+ NdArray <Boolean > data = (NdArray <Boolean >) tensor ;
108+
109+ // Copy first 2 values of the vector to the first row of the matrix
110+ data .set (heapData .slice (Indices .range (0 , 2 )), 0 );
111+
112+ // Copy values at an odd position in the vector as the second row of the matrix
113+ data .set (heapData .slice (Indices .odd ()), 1 );
114+
115+ assertEquals (true , data .getObject (0 , 0 ));
116+ assertEquals (false , data .getObject (0 , 1 ));
117+ assertEquals (false , data .getObject (1 , 0 ));
118+ assertEquals (false , data .getObject (1 , 1 ));
119+
120+ // Read rows of the tensor in reverse order
121+ NdArray <Boolean > flippedData = data .slice (Indices .flip (), Indices .flip ());
122+
123+ assertEquals (false , flippedData .getObject (0 , 0 ));
124+ assertEquals (false , flippedData .getObject (0 , 1 ));
125+ assertEquals (false , flippedData .getObject (1 , 0 ));
126+ assertEquals (true , flippedData .getObject (1 , 1 ));
127+
128+ try (EagerSession session = EagerSession .create ()) {
129+ Ops tf = Ops .create (session );
130+
131+ LogicalNot sub = tf .math .logicalNot (tf .constantOf (tensor ));
132+ NdArray <Boolean > result = (NdArray <Boolean >) sub .asTensor ();
133+
134+ assertEquals (false , result .getObject (0 , 0 ));
135+ assertEquals (true , result .getObject (0 , 1 ));
136+ assertEquals (true , result .getObject (1 , 0 ));
137+ assertEquals (true , result .getObject (1 , 1 ));
138+ }
139+ }
140+ }
141+
142+ }
0 commit comments