Skip to content

Commit f14033b

Browse files
DOC fix AttributeError in Sequential object - predict_proba (#993)
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent 2b126e3 commit f14033b

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

examples/applications/porto_seguro_keras_under_sampling.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,23 @@ def wrapper(*args, **kwds):
149149
###############################################################################
150150
# The first model will be trained using the ``fit`` method and with imbalanced
151151
# mini-batches.
152-
152+
import tensorflow
153153
from sklearn.metrics import roc_auc_score
154+
from sklearn.utils import parse_version
155+
156+
tf_version = parse_version(tensorflow.__version__)
154157

155158

156159
@timeit
157160
def fit_predict_imbalanced_model(X_train, y_train, X_test, y_test):
158161
model = make_model(X_train.shape[1])
159162
model.fit(X_train, y_train, epochs=2, verbose=1, batch_size=1000)
160-
y_pred = model.predict_proba(X_test, batch_size=1000)
163+
if tf_version < parse_version("2.6"):
164+
# predict_proba was removed in tensorflow 2.6
165+
predict_method = "predict_proba"
166+
else:
167+
predict_method = "predict"
168+
y_pred = getattr(model, predict_method)(X_test, batch_size=1000)
161169
return roc_auc_score(y_test, y_pred)
162170

163171

0 commit comments

Comments
 (0)