Skip to content

Commit 4c35c09

Browse files
committed
add testing for multi-output error estimation
1 parent 1c46ff4 commit 4c35c09

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

forestci/tests/test_forestci.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,43 @@ def test_random_forest_error():
4040
)
4141

4242

43+
def test_random_forest_error_multioutput():
44+
X = np.array([[5, 2], [5, 5], [3, 3], [6, 4], [6, 6]])
45+
46+
y = np.array([[70, 37], [100, 55], [60, 33], [100,54], [120, 66]])
47+
48+
train_idx = [2, 3, 4]
49+
test_idx = [0, 1]
50+
51+
y_test = y[test_idx]
52+
y_train = y[train_idx]
53+
X_test = X[test_idx]
54+
X_train = X[train_idx]
55+
56+
n_trees = 4
57+
forest = RandomForestRegressor(n_estimators=n_trees)
58+
forest.fit(X_train, y_train)
59+
60+
V_IJ_unbiased_target0 = fci.random_forest_error(
61+
forest, X_train.shape, X_test, calibrate=True, y_output=0
62+
)
63+
npt.assert_equal(V_IJ_unbiased_target0.shape[0], y_test.shape[0])
64+
65+
# With a MultiOutput RandomForestRegressor the user MUST specify a y_output
66+
npt.assert_raises(
67+
ValueError,
68+
fci.random_forest_error,
69+
forest,
70+
X_train.shape,
71+
X_test,
72+
inbag=None,
73+
calibrate=True,
74+
memory_constrained=False,
75+
memory_limit=None,
76+
y_output=None # This should trigger the ValueError
77+
)
78+
79+
4380
def test_bagging_svr_error():
4481
X = np.array([[5, 2], [5, 5], [3, 3], [6, 4], [6, 6]])
4582

0 commit comments

Comments
 (0)