1717import static org .junit .jupiter .api .Assertions .assertThrows ;
1818
1919import org .junit .jupiter .api .Test ;
20+ import org .tensorflow .Graph ;
2021import org .tensorflow .Operand ;
22+ import org .tensorflow .Session ;
2123import org .tensorflow .framework .utils .TestSession ;
2224import org .tensorflow .ndarray .Shape ;
25+ import org .tensorflow .ndarray .buffer .DataBuffers ;
2326import org .tensorflow .op .Ops ;
27+ import org .tensorflow .op .core .Placeholder ;
2428import org .tensorflow .types .TFloat32 ;
2529import org .tensorflow .types .TInt32 ;
2630import org .tensorflow .types .TInt64 ;
@@ -36,16 +40,8 @@ public void testAllCorrectUnweighted() {
3640 try (TestSession testSession = TestSession .createTestSession (tfMode )) {
3741 Ops tf = testSession .getTF ();
3842
39- long [] trueArray = {
40- 1L , 0L , 0L ,
41- 0L , 1L , 0L ,
42- 0L , 0L , 1L
43- };
44- float [] predArray = {
45- 1.F , 0.F , 0.F ,
46- 0.F , 1.F , 0.F ,
47- 0.F , 0.F , 1.F
48- };
43+ long [] trueArray = {1L , 0L , 0L , 0L , 1L , 0L , 0L , 0L , 1L };
44+ float [] predArray = {1.F , 0.F , 0.F , 0.F , 1.F , 0.F , 0.F , 0.F , 1.F };
4945 Operand <TInt64 > yTrue = tf .reshape (tf .constant (trueArray ), tf .constant (Shape .of (3 , 3 )));
5046 Operand <TFloat32 > yPred = tf .reshape (tf .constant (predArray ), tf .constant (Shape .of (3 , 3 )));
5147 CategoricalCrossentropy instance = new CategoricalCrossentropy ();
@@ -55,11 +51,7 @@ public void testAllCorrectUnweighted() {
5551 testSession .evaluate (expected , loss );
5652
5753 // Test with logits.
58- float [] logitsArray = {
59- 10.F , 0.F , 0.F ,
60- 0.F , 10.F , 0.F ,
61- 0.F , 0.F , 10.F
62- };
54+ float [] logitsArray = {10.F , 0.F , 0.F , 0.F , 10.F , 0.F , 0.F , 0.F , 10.F };
6355 yTrue = tf .reshape (tf .constant (trueArray ), tf .constant (Shape .of (3 , 3 )));
6456 Operand <TFloat32 > logits =
6557 tf .reshape (tf .constant (logitsArray ), tf .constant (Shape .of (3 , 3 )));
@@ -85,11 +77,7 @@ public void testInvalidPredictionsRange() {
8577 Ops tf = testSession .getTF ();
8678 CategoricalCrossentropy instance = new CategoricalCrossentropy ();
8779
88- float [] trueArray = {
89- 1L , 0L , 0L ,
90- 0L , 1L , 0L ,
91- 0L , 0L , 1L
92- };
80+ float [] trueArray = {1L , 0L , 0L , 0L , 1L , 0L , 0L , 0L , 1L };
9381 float [] predArray = {-1.F , 0.F , 0.F , 0.F , 1.F , 0.F , 0.F , 0.F , 1.F };
9482 Operand <TFloat32 > yTrue =
9583 tf .reshape (tf .constant (trueArray ), tf .constant (Shape .of (3 , 3 )));
@@ -111,23 +99,15 @@ public void testUnweighted() {
11199 CategoricalCrossentropy instance = new CategoricalCrossentropy ();
112100
113101 int [] trueArray = {1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 1 };
114- float [] predArray = {
115- .9F , .05F , .05F ,
116- .5F , .89F , .6F ,
117- .05F , .01F , .94F
118- };
102+ float [] predArray = {.9F , .05F , .05F , .5F , .89F , .6F , .05F , .01F , .94F };
119103 Operand <TInt32 > yTrue = tf .reshape (tf .constant (trueArray ), tf .constant (Shape .of (3 , 3 )));
120104 Operand <TFloat32 > yPred = tf .reshape (tf .constant (predArray ), tf .constant (Shape .of (3 , 3 )));
121105 Operand <TFloat32 > loss = instance .call (tf , yTrue , yPred );
122106 float expected = 0.32396814F ;
123107 testSession .evaluate (expected , loss );
124108
125109 // Test with logits.
126- float [] logitsArray = {
127- 8.F , 1.F , 1.F ,
128- 0.F , 9.F , 1.F ,
129- 2.F , 3.F , 5.F
130- };
110+ float [] logitsArray = {8.F , 1.F , 1.F , 0.F , 9.F , 1.F , 2.F , 3.F , 5.F };
131111 Operand <TFloat32 > logits =
132112 tf .reshape (tf .constant (logitsArray ), tf .constant (Shape .of (3 , 3 )));
133113 instance = new CategoricalCrossentropy (true );
@@ -145,16 +125,8 @@ public void testScalarWeighted() {
145125 try (TestSession testSession = TestSession .createTestSession (tfMode )) {
146126 Ops tf = testSession .getTF ();
147127
148- int [] trueArray = {
149- 1 , 0 , 0 ,
150- 0 , 1 , 0 ,
151- 0 , 0 , 1
152- };
153- float [] predArray = {
154- .9F , .05F , .05F ,
155- .5F , .89F , .6F ,
156- .05F , .01F , .94F
157- };
128+ int [] trueArray = {1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 1 };
129+ float [] predArray = {.9F , .05F , .05F , .5F , .89F , .6F , .05F , .01F , .94F };
158130 Operand <TInt32 > yTrue = tf .reshape (tf .constant (trueArray ), tf .constant (Shape .of (3 , 3 )));
159131 Operand <TFloat32 > yPred = tf .reshape (tf .constant (predArray ), tf .constant (Shape .of (3 , 3 )));
160132 Operand <TFloat32 > sampleWeight = tf .constant (2.3F );
@@ -166,11 +138,7 @@ public void testScalarWeighted() {
166138 testSession .evaluate (expected , loss );
167139
168140 // Test with logits.
169- float [] logitsArray = {
170- 8.F , 1.F , 1.F ,
171- 0.F , 9.F , 1.F ,
172- 2.F , 3.F , 5.F
173- };
141+ float [] logitsArray = {8.F , 1.F , 1.F , 0.F , 9.F , 1.F , 2.F , 3.F , 5.F };
174142 Operand <TFloat32 > logits =
175143 tf .reshape (tf .constant (logitsArray ), tf .constant (Shape .of (3 , 3 )));
176144 instance = new CategoricalCrossentropy (true );
@@ -189,16 +157,8 @@ public void testSsampleWeighted() {
189157 CategoricalCrossentropy instance = new CategoricalCrossentropy ();
190158
191159 float [] sampeWeightArray = {1.2F , 3.4F , 5.6F };
192- int [] trueArray = {
193- 1 , 0 , 0 ,
194- 0 , 1 , 0 ,
195- 0 , 0 , 1
196- };
197- float [] predArray = {
198- .9F , .05F , .05F ,
199- .5F , .89F , .6F ,
200- .05F , .01F , .94F
201- };
160+ int [] trueArray = {1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 1 };
161+ float [] predArray = {.9F , .05F , .05F , .5F , .89F , .6F , .05F , .01F , .94F };
202162 Operand <TInt32 > yTrue = tf .reshape (tf .constant (trueArray ), tf .constant (Shape .of (3 , 3 )));
203163 Operand <TFloat32 > yPred = tf .reshape (tf .constant (predArray ), tf .constant (Shape .of (3 , 3 )));
204164 Operand <TFloat32 > sampleWeight =
@@ -208,11 +168,7 @@ public void testSsampleWeighted() {
208168 testSession .evaluate (expected , loss );
209169
210170 // Test with logits.
211- float [] logitsArray = {
212- 8.F , 1.F , 1.F ,
213- 0.F , 9.F , 1.F ,
214- 2.F , 3.F , 5.F
215- };
171+ float [] logitsArray = {8.F , 1.F , 1.F , 0.F , 9.F , 1.F , 2.F , 3.F , 5.F };
216172 Operand <TFloat32 > logits =
217173 tf .reshape (tf .constant (logitsArray ), tf .constant (Shape .of (3 , 3 )));
218174 instance = new CategoricalCrossentropy (true );
@@ -231,11 +187,7 @@ public void testNoReduction() {
231187
232188 // Test with logits.
233189 int [] trueArray = {1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 1 };
234- float [] logitsArray = {
235- 8.F , 1.F , 1.F ,
236- 0.F , 9.F , 1.F ,
237- 2.F , 3.F , 5.F
238- };
190+ float [] logitsArray = {8.F , 1.F , 1.F , 0.F , 9.F , 1.F , 2.F , 3.F , 5.F };
239191 Operand <TInt32 > yTrue = tf .reshape (tf .constant (trueArray ), tf .constant (Shape .of (3 , 3 )));
240192 Operand <TFloat32 > logits =
241193 tf .reshape (tf .constant (logitsArray ), tf .constant (Shape .of (3 , 3 )));
@@ -266,4 +218,34 @@ public void testLabelSmoothing() {
266218 testSession .evaluate (expected , loss );
267219 }
268220 }
221+
222+ @ Test
223+ public void testCategoricalCrossEntopyWithDynamicBatchSize () {
224+ try (Graph graph = new Graph ()) {
225+ Ops tf = Ops .create (graph );
226+ Operand yPred = tf .placeholder (TFloat32 .class , Placeholder .shape (Shape .of (-1 , 3 )));
227+ Operand yTrue =
228+ tf .reshape (tf .constant (new float [] {1f , 0f , 0f , 0f , 1f , 0f , 0f , 0f , 1f }), tf .array (3 , 3 ));
229+ CategoricalCrossentropy instance = new CategoricalCrossentropy (true );
230+ Operand loss =
231+ instance .call (tf , yTrue , yPred ); // Throw TFInvalidArgument Exception without fix
232+ try (Session session = new Session (graph );
233+ TFloat32 result =
234+ (TFloat32 )
235+ session
236+ .runner ()
237+ .feed (
238+ yPred ,
239+ TFloat32 .tensorOf (
240+ Shape .of (3 , 3 ),
241+ DataBuffers .of (
242+ new float [] {1.f , 0.f , 0.f , 0.f , 1.f , 0.f , 0.f , 0.f , 1.f })))
243+ .fetch (loss )
244+ .run ()
245+ .get (0 )) {
246+ if (Math .abs (0.5514477f - result .getFloat ()) > 0.01 )
247+ throw new IllegalStateException ("Invalid result :" + result .getFloat ());
248+ }
249+ }
250+ }
269251}
0 commit comments