Skip to content
This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Commit 0c57c03

Browse files
committed
Adding support to split numpy arrays into batches for prediction
1 parent c977e83 commit 0c57c03

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

skflow/io/data_feeder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import itertools
2020
import six
2121
from six.moves import xrange # pylint: disable=redefined-builtin
22+
import math
2223

2324
import numpy as np
2425
from sklearn.utils import check_array
@@ -117,6 +118,9 @@ def setup_predict_data_feeder(X, batch_size=-1):
117118
return _batch_data(X, batch_size)
118119
if len(X.shape) == 1:
119120
X = np.reshape(X, (-1, 1))
121+
if batch_size > 0:
122+
n_batches = int(math.ceil(float(len(X)) / batch_size))
123+
return [X[i * batch_size:(i + 1) * batch_size] for i in xrange(n_batches)]
120124
return [X]
121125

122126

0 commit comments

Comments
 (0)