Skip to content
This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Commit c977e83

Browse files
committed
Adding a batch_size to predict method to split into smaller mini-batches if required
1 parent 7bc6f52 commit c977e83

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

skflow/estimators/base.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -239,11 +239,12 @@ def partial_fit(self, X, y):
239239
"""
240240
return self.fit(X, y)
241241

242-
def _predict(self, X):
242+
def _predict(self, X, batch_size=-1):
243243
if not self._initialized:
244244
raise NotFittedError()
245245
self._graph.add_to_collection("IS_TRAINING", False)
246-
predict_data_feeder = setup_predict_data_feeder(X)
246+
predict_data_feeder = setup_predict_data_feeder(
247+
X, batch_size=batch_size)
247248
preds = []
248249
dropouts = self._graph.get_collection(DROPOUTS)
249250
feed_dict = {prob: 1.0 for prob in dropouts}
@@ -254,7 +255,7 @@ def _predict(self, X):
254255
feed_dict))
255256
return np.concatenate(preds, axis=0)
256257

257-
def predict(self, X, axis=1):
258+
def predict(self, X, axis=1, batch_size=-1):
258259
"""Predict class or regression for X.
259260
260261
For a classification model, the predicted class for each sample in X is
@@ -263,27 +264,35 @@ def predict(self, X, axis=1):
263264
264265
Args:
265266
X: array-like matrix, [n_samples, n_features...] or iterator.
267+
axis: Which axis to argmax for classification.
268+
By default axis 1 (next after batch) is used.
269+
Use 2 for sequence predictions.
270+
batch_size: If test set is too big, use batch size to split
271+
it into mini batches. By default full dataset is used.
266272
267273
Returns:
268274
y: array of shape [n_samples]. The predicted classes or predicted
269275
value.
270276
"""
271-
pred = self._predict(X)
277+
pred = self._predict(X, batch_size=batch_size)
272278
if self.n_classes < 2:
273279
return pred
274280
return pred.argmax(axis=axis)
275281

276-
def predict_proba(self, X):
282+
def predict_proba(self, X, batch_size=-1):
277283
"""Predict class probability of the input samples X.
278284
279285
Args:
280286
X: array-like matrix, [n_samples, n_features...] or iterator.
287+
batch_size: If test set is too big, use batch size to split
288+
it into mini batches. By default full dataset is used.
281289
282290
Returns:
283291
y: array of shape [n_samples, n_classes]. The predicted
284292
probabilities for each class.
285-
"""
286-
return self._predict(X)
293+
294+
"""
295+
return self._predict(X, batch_size=batch_size)
287296

288297
def get_tensor(self, name):
289298
"""Returns tensor by name.

0 commit comments

Comments
 (0)