Skip to content

Commit 9645f75

Browse files
author
Dominik Waurenschk
committed
fix tests
1 parent b483e44 commit 9645f75

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

forestci/tests/test_forestci.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_random_forest_error():
2626
for ib in [inbag, None]:
2727
for calibrate in [True, False]:
2828
V_IJ_unbiased = fci.random_forest_error(
29-
forest, X_train, X_test, inbag=ib, calibrate=calibrate
29+
forest, X_train.shape, X_test, inbag=ib, calibrate=calibrate
3030
)
3131
npt.assert_equal(V_IJ_unbiased.shape[0], y_test.shape[0])
3232

@@ -60,7 +60,7 @@ def test_bagging_svr_error():
6060
for ib in [inbag, None]:
6161
for calibrate in [True, False]:
6262
V_IJ_unbiased = fci.random_forest_error(
63-
bagger, X_train, X_test, inbag=ib, calibrate=calibrate
63+
bagger, X_train.shape, X_test, inbag=ib, calibrate=calibrate
6464
)
6565
npt.assert_equal(V_IJ_unbiased.shape[0], y_test.shape[0])
6666

@@ -78,7 +78,7 @@ def test_core_computation():
7878
n_trees = 4
7979

8080
our_vij = fci._core_computation(
81-
X_train_ex, X_test_ex, inbag_ex, pred_centered_ex, n_trees
81+
X_train_ex.shape, X_test_ex, inbag_ex, pred_centered_ex, n_trees
8282
)
8383

8484
r_vij = np.concatenate([np.array([112.5, 387.5]) for _ in range(1000)])
@@ -87,7 +87,7 @@ def test_core_computation():
8787

8888
for mc, ml in zip([True, False], [0.01, None]):
8989
our_vij = fci._core_computation(
90-
X_train_ex,
90+
X_train_ex.shape,
9191
X_test_ex,
9292
inbag_ex,
9393
pred_centered_ex,
@@ -113,7 +113,7 @@ def test_bias_correction():
113113
n_trees = 4
114114

115115
our_vij = fci._core_computation(
116-
X_train_ex, X_test_ex, inbag_ex, pred_centered_ex, n_trees
116+
X_train_ex.shape, X_test_ex, inbag_ex, pred_centered_ex, n_trees
117117
)
118118
our_vij_unbiased = fci._bias_correction(
119119
our_vij, inbag_ex, pred_centered_ex, n_trees
@@ -139,7 +139,7 @@ def test_with_calibration():
139139
n_trees = 4
140140
forest = RandomForestRegressor(n_estimators=n_trees)
141141
forest.fit(X_train, y_train)
142-
V_IJ_unbiased = fci.random_forest_error(forest, X_train, X_test)
142+
V_IJ_unbiased = fci.random_forest_error(forest, X_train.shape, X_test)
143143
npt.assert_equal(V_IJ_unbiased.shape[0], y_test.shape[0])
144144

145145

0 commit comments

Comments
 (0)