Skip to content

Commit ffa7227

Browse files
authored
Merge pull request #113 from danieleongari/handle_multioutput_model
Handle MultiOutput model
2 parents 17152c1 + 4c35c09 commit ffa7227

File tree

2 files changed

+61
-8
lines changed

2 files changed

+61
-8
lines changed

forestci/forestci.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def _bias_correction(V_IJ, inbag, pred_centered, n_trees):
200200
return V_IJ_unbiased
201201

202202

203-
def _centered_prediction_forest(forest, X_test):
203+
def _centered_prediction_forest(forest, X_test, y_output=None):
204204
"""
205205
Center the tree predictions by the mean prediction (forest)
206206
@@ -224,16 +224,20 @@ def _centered_prediction_forest(forest, X_test):
224224
mean prediction (i.e. the prediction of the forest)
225225
226226
"""
227-
# reformatting required for single sample arrays
228-
# caution: assumption that number of features always > 1
227+
# In case the user provided a (n_features)-shaped array for a single sample
228+
# shape it as (1, n_features)
229+
# NOTE: a single-feature set of samples needs to be provided with shape
230+
# (n_samples, 1) or it will be wrongly interpreted!
229231
if len(X_test.shape) == 1:
230-
# reshape according to the reshaping annotation in scikit-learn
231232
X_test = X_test.reshape(1, -1)
232233

233-
pred = np.array([tree.predict(X_test) for tree in forest]).T
234-
pred_mean = np.mean(pred, 1).reshape(X_test.shape[0], 1)
234+
pred = np.array([tree.predict(X_test) for tree in forest])
235+
if 'n_outputs_' in dir(forest) and forest.n_outputs_ > 1:
236+
pred = pred[:,:,y_output]
235237

236-
return pred - pred_mean
238+
pred_mean = np.mean(pred, 0)
239+
240+
return (pred - pred_mean).T
237241

238242

239243
def random_forest_error(
@@ -244,6 +248,7 @@ def random_forest_error(
244248
calibrate=True,
245249
memory_constrained=False,
246250
memory_limit=None,
251+
y_output=None
247252
):
248253
"""
249254
Calculate error bars from scikit-learn RandomForest estimators.
@@ -286,6 +291,11 @@ def random_forest_error(
286291
An upper bound for how much memory the itermediate matrices will take
287292
up in Megabytes. This must be provided if memory_constrained=True.
288293
294+
y_output: int, mandatory only for MultiOutput regressor.
295+
In case of MultiOutput regressor, indicate the index of the target to
296+
analyse. The program will return the IJ variance related to that target
297+
only.
298+
289299
Returns
290300
-------
291301
An array with the unbiased sampling variance (V_IJ_unbiased)
@@ -305,10 +315,15 @@ def random_forest_error(
305315
Random Forests: The Jackknife and the Infinitesimal Jackknife", Journal
306316
of Machine Learning Research vol. 15, pp. 1625-1651, 2014.
307317
"""
318+
319+
if 'n_outputs_' in dir(forest) and forest.n_outputs_ > 1 and y_output == None:
320+
e_s = "MultiOutput regressor: specify the index of the target to analyse (y_output)"
321+
raise ValueError(e_s)
322+
308323
if inbag is None:
309324
inbag = calc_inbag(X_train_shape[0], forest)
310325

311-
pred_centered = _centered_prediction_forest(forest, X_test)
326+
pred_centered = _centered_prediction_forest(forest, X_test, y_output)
312327
n_trees = forest.n_estimators
313328
V_IJ = _core_computation(
314329
X_train_shape, X_test, inbag, pred_centered, n_trees, memory_constrained, memory_limit
@@ -348,6 +363,7 @@ def random_forest_error(
348363
calibrate=False,
349364
memory_constrained=memory_constrained,
350365
memory_limit=memory_limit,
366+
y_output=y_output
351367
)
352368
# Use this second set of variance estimates
353369
# to estimate scale of Monte Carlo noise

forestci/tests/test_forestci.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,43 @@ def test_random_forest_error():
4040
)
4141

4242

43+
def test_random_forest_error_multioutput():
44+
X = np.array([[5, 2], [5, 5], [3, 3], [6, 4], [6, 6]])
45+
46+
y = np.array([[70, 37], [100, 55], [60, 33], [100,54], [120, 66]])
47+
48+
train_idx = [2, 3, 4]
49+
test_idx = [0, 1]
50+
51+
y_test = y[test_idx]
52+
y_train = y[train_idx]
53+
X_test = X[test_idx]
54+
X_train = X[train_idx]
55+
56+
n_trees = 4
57+
forest = RandomForestRegressor(n_estimators=n_trees)
58+
forest.fit(X_train, y_train)
59+
60+
V_IJ_unbiased_target0 = fci.random_forest_error(
61+
forest, X_train.shape, X_test, calibrate=True, y_output=0
62+
)
63+
npt.assert_equal(V_IJ_unbiased_target0.shape[0], y_test.shape[0])
64+
65+
# With a MultiOutput RandomForestRegressor the user MUST specify a y_output
66+
npt.assert_raises(
67+
ValueError,
68+
fci.random_forest_error,
69+
forest,
70+
X_train.shape,
71+
X_test,
72+
inbag=None,
73+
calibrate=True,
74+
memory_constrained=False,
75+
memory_limit=None,
76+
y_output=None # This should trigger the ValueError
77+
)
78+
79+
4380
def test_bagging_svr_error():
4481
X = np.array([[5, 2], [5, 5], [3, 3], [6, 4], [6, 6]])
4582

0 commit comments

Comments
 (0)