@@ -239,11 +239,12 @@ def partial_fit(self, X, y):
239239 """
240240 return self .fit (X , y )
241241
242- def _predict (self , X ):
242+ def _predict (self , X , batch_size = - 1 ):
243243 if not self ._initialized :
244244 raise NotFittedError ()
245245 self ._graph .add_to_collection ("IS_TRAINING" , False )
246- predict_data_feeder = setup_predict_data_feeder (X )
246+ predict_data_feeder = setup_predict_data_feeder (
247+ X , batch_size = batch_size )
247248 preds = []
248249 dropouts = self ._graph .get_collection (DROPOUTS )
249250 feed_dict = {prob : 1.0 for prob in dropouts }
@@ -254,7 +255,7 @@ def _predict(self, X):
254255 feed_dict ))
255256 return np .concatenate (preds , axis = 0 )
256257
257- def predict (self , X , axis = 1 ):
258+ def predict (self , X , axis = 1 , batch_size = - 1 ):
258259 """Predict class or regression for X.
259260
260261 For a classification model, the predicted class for each sample in X is
@@ -263,27 +264,35 @@ def predict(self, X, axis=1):
263264
264265 Args:
265266 X: array-like matrix, [n_samples, n_features...] or iterator.
267+ axis: Which axis to argmax for classification.
268+ By default axis 1 (next after batch) is used.
269+ Use 2 for sequence predictions.
270+ batch_size: If test set is too big, use batch size to split
271+ it into mini batches. By default full dataset is used.
266272
267273 Returns:
268274 y: array of shape [n_samples]. The predicted classes or predicted
269275 value.
270276 """
271- pred = self ._predict (X )
277+ pred = self ._predict (X , batch_size = batch_size )
272278 if self .n_classes < 2 :
273279 return pred
274280 return pred .argmax (axis = axis )
275281
276- def predict_proba (self , X ):
282+ def predict_proba (self , X , batch_size = - 1 ):
277283 """Predict class probability of the input samples X.
278284
279285 Args:
280286 X: array-like matrix, [n_samples, n_features...] or iterator.
287+ batch_size: If test set is too big, use batch size to split
288+ it into mini batches. By default full dataset is used.
281289
282290 Returns:
283291 y: array of shape [n_samples, n_classes]. The predicted
284292 probabilities for each class.
285- """
286- return self ._predict (X )
293+
294+ """
295+ return self ._predict (X , batch_size = batch_size )
287296
288297 def get_tensor (self , name ):
289298 """Returns tensor by name.
0 commit comments