Skip to content

Commit 069a3ed

Browse files
authored
Merge pull request #236 from xionghhcs/master
fix the tl.utils.predict 's bug. when the data size can not be exactly divided by batch_size, there will be some predict result lost. by @xionghhcs
2 parents 6a8c670 + 7d5e21b commit 069a3ed

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

tensorlayer/utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,21 @@ def predict(sess, network, X, x, y_op, batch_size=None):
281281
if result is None:
282282
result = result_a
283283
else:
284-
result = np.hstack((result, result_a))
284+
result = np.vstack((result, result_a))
285+
if result is None:
286+
if len(X) % batch_size != 0:
287+
dp_dict = dict_to_one(network.all_drop)
288+
feed_dict = {x: X[-(len(X) % batch_size):, :], }
289+
feed_dict.update(dp_dict)
290+
result_a = sess.run(y_op, feed_dict=feed_dict)
291+
result = result_a
292+
else:
293+
if len(X) != len(result) and len(X) % batch_size != 0:
294+
dp_dict = dict_to_one(network.all_drop)
295+
feed_dict = {x: X[-(len(X) % batch_size):, :], }
296+
feed_dict.update(dp_dict)
297+
result_a = sess.run(y_op, feed_dict=feed_dict)
298+
result = np.vstack((result, result_a))
285299
return result
286300

287301

0 commit comments

Comments
 (0)