-
Notifications
You must be signed in to change notification settings - Fork 223
Feature/add t bool test #628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 2 commits
56c41de
0e67a8b
7141c8c
b87bfb2
3b2c7fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,142 @@ | ||
| package org.tensorflow.types; | ||
|
|
||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||
| import static org.junit.jupiter.api.Assertions.assertNotNull; | ||
|
|
||
| import org.junit.jupiter.api.Test; | ||
| import org.tensorflow.EagerSession; | ||
| import org.tensorflow.ndarray.NdArray; | ||
| import org.tensorflow.ndarray.NdArrays; | ||
| import org.tensorflow.ndarray.Shape; | ||
| import org.tensorflow.ndarray.index.Indices; | ||
| import org.tensorflow.op.Ops; | ||
| import org.tensorflow.op.core.Constant; | ||
| import org.tensorflow.op.math.LogicalAnd; | ||
| import org.tensorflow.op.math.LogicalNot; | ||
| import org.tensorflow.op.math.LogicalOr; | ||
|
|
||
| class TBoolTest { | ||
|
||
|
|
||
| @Test | ||
| void createScalar() { | ||
| TBool tensorT = TBool.scalarOf(true); | ||
| assertNotNull(tensorT); | ||
| assertEquals(Shape.scalar(), tensorT.shape()); | ||
| assertEquals(true, tensorT.getObject()); | ||
|
|
||
| TBool tensorF = TBool.scalarOf(false); | ||
| assertNotNull(tensorF); | ||
| assertEquals(Shape.scalar(), tensorF.shape()); | ||
| assertEquals(false, tensorF.getObject()); | ||
| } | ||
|
|
||
| @Test | ||
| void createVector() { | ||
| TBool tensor = TBool.vectorOf(true, false); | ||
| assertNotNull(tensor); | ||
| assertEquals(Shape.of(2), tensor.shape()); | ||
| assertEquals(true, tensor.getObject(0)); | ||
| assertEquals(false, tensor.getObject(1)); | ||
| } | ||
|
|
||
| @Test | ||
| void createCopy() { | ||
| NdArray<Boolean> bools = | ||
| NdArrays.ofObjects(Boolean.class, Shape.of(2, 2)) | ||
| .setObject(true, 0, 0) | ||
| .setObject(false, 0, 1) | ||
| .setObject(false, 1, 0) | ||
| .setObject(true, 1, 1); | ||
|
|
||
| TBool tensor = TBool.tensorOf(bools); | ||
| assertNotNull(tensor); | ||
| bools | ||
| .scalars() | ||
| .forEachIndexed((idx, s) -> assertEquals(s.getObject(), tensor.getObject(idx))); | ||
| } | ||
|
|
||
| @Test | ||
| void initializeTensorsWithBools() { | ||
| // Allocate a tensor of booleans of the shape (2, 3, 2) | ||
| TBool tensor = TBool.tensorOf(Shape.of(2, 3, 2)); | ||
|
|
||
| assertEquals(3, tensor.rank()); | ||
| assertEquals(12, tensor.size()); | ||
| NdArray<Boolean> data = (NdArray<Boolean>) tensor; | ||
|
|
||
| try (EagerSession session = EagerSession.create()) { | ||
| Ops tf = Ops.create(session); | ||
|
|
||
| // Initialize tensor memory with falses and take a snapshot | ||
| data.scalars().forEach(scalar -> ((NdArray<Boolean>) scalar).setObject(false)); | ||
| Constant<TBool> x = tf.constantOf(tensor); | ||
|
|
||
| // Initialize the same tensor memory with trues and take a snapshot | ||
| data.scalars().forEach(scalar -> ((NdArray<Boolean>) scalar).setObject(true)); | ||
| Constant<TBool> y = tf.constantOf(tensor); | ||
|
|
||
| // Calculate x AND y and validate the result | ||
| LogicalAnd xAndY = tf.math.logicalAnd(x, y); | ||
| ((NdArray<Boolean>) xAndY.asTensor()) | ||
| .scalars() | ||
| .forEach(scalar -> assertEquals(false, scalar.getObject())); | ||
|
|
||
| // Calculate x AND y and validate the result | ||
|
||
| LogicalOr xOrY = tf.math.logicalOr(x, y); | ||
| ((NdArray<Boolean>) xOrY.asTensor()) | ||
| .scalars() | ||
| .forEach(scalar -> assertEquals(true, scalar.getObject())); | ||
|
|
||
| // Calculate !x and validate the result against y | ||
| LogicalNot notX = tf.math.logicalNot(x); | ||
| assertEquals(y.asTensor(), notX.asTensor()); | ||
| } | ||
| } | ||
|
|
||
| @Test | ||
| void setAndCompute() { | ||
| NdArray<Boolean> heapData = | ||
| NdArrays.ofBooleans(Shape.of(4)) | ||
| .setObject(true, 0) | ||
| .setObject(false, 1) | ||
| .setObject(true, 2) | ||
| .setObject(false, 3); | ||
|
|
||
| // Creates a 2x2 matrix | ||
| try (TBool tensor = TBool.tensorOf(Shape.of(2, 2))) { | ||
| NdArray<Boolean> data = (NdArray<Boolean>) tensor; | ||
|
|
||
| // Copy first 2 values of the vector to the first row of the matrix | ||
| data.set(heapData.slice(Indices.range(0, 2)), 0); | ||
|
|
||
| // Copy values at an odd position in the vector as the second row of the matrix | ||
| data.set(heapData.slice(Indices.odd()), 1); | ||
|
|
||
| assertEquals(true, data.getObject(0, 0)); | ||
| assertEquals(false, data.getObject(0, 1)); | ||
| assertEquals(false, data.getObject(1, 0)); | ||
| assertEquals(false, data.getObject(1, 1)); | ||
|
|
||
| // Read rows of the tensor in reverse order | ||
| NdArray<Boolean> flippedData = data.slice(Indices.flip(), Indices.flip()); | ||
|
|
||
| assertEquals(false, flippedData.getObject(0, 0)); | ||
| assertEquals(false, flippedData.getObject(0, 1)); | ||
| assertEquals(false, flippedData.getObject(1, 0)); | ||
| assertEquals(true, flippedData.getObject(1, 1)); | ||
|
|
||
| try (EagerSession session = EagerSession.create()) { | ||
| Ops tf = Ops.create(session); | ||
|
|
||
| LogicalNot sub = tf.math.logicalNot(tf.constantOf(tensor)); | ||
| NdArray<Boolean> result = (NdArray<Boolean>) sub.asTensor(); | ||
|
|
||
| assertEquals(false, result.getObject(0, 0)); | ||
| assertEquals(true, result.getObject(0, 1)); | ||
| assertEquals(true, result.getObject(1, 0)); | ||
| assertEquals(true, result.getObject(1, 1)); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add the same copyright statement as the other tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oops, added