Skip to content

Commit cd7f6ef

Browse files
authored
Merge pull request #20 from vecxoz/dev
Allow n-dimensional input for func API. Some maintenance for tests.
2 parents be81eff + 5ada0fb commit cd7f6ef

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

tests/test_sklearn_api_classification_binary.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ def roc_auc_score_universal(y_true, y_pred):
6767
"""
6868
ohe = OneHotEncoder(sparse=False)
6969
y_true = ohe.fit_transform(y_true.reshape(-1, 1))
70+
#@@@@
71+
if len(y_pred.shape) == 1:
72+
y_pred = np.c_[y_pred, y_pred]
73+
y_pred[:, 0] = 1 - y_pred[:, 1]
74+
#@@@@
7075
auc_score = roc_auc_score(y_true, y_pred)
7176
return auc_score
7277

vecstack/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,11 +413,13 @@ def your_metric(y_true, y_pred):
413413
accept_sparse=['csr'], # allow csr and cast all other sparse types to csr
414414
force_all_finite=False, # allow nan and inf because
415415
# some models (xgboost) can handle
416+
allow_nd=True,
416417
multi_output=False) # do not allow several columns in y_train
417418

418419
if X_test is not None: # allow X_test to be None for mode='oof'
419420
X_test = check_array(X_test,
420421
accept_sparse=['csr'], # allow csr and cast all other sparse types to csr
422+
allow_nd=True,
421423
force_all_finite=False) # allow nan and inf because
422424
# some models (xgboost) can handle
423425
if sample_weight is not None:

0 commit comments

Comments
 (0)