Skip to content

Commit 6ee5fcf

Browse files
committed
implement option y_output to handle MultiOutput estimators
1 parent 2b05a96 commit 6ee5fcf

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

forestci/forestci.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def _bias_correction(V_IJ, inbag, pred_centered, n_trees):
200200
return V_IJ_unbiased
201201

202202

203-
def _centered_prediction_forest(forest, X_test):
203+
def _centered_prediction_forest(forest, X_test, y_output):
204204
"""
205205
Center the tree predictions by the mean prediction (forest)
206206
@@ -232,6 +232,9 @@ def _centered_prediction_forest(forest, X_test):
232232
X_test = X_test.reshape(1, -1)
233233

234234
pred = np.array([tree.predict(X_test) for tree in forest])
235+
if forest.n_outputs_ > 1:
236+
pred = pred[:,:,y_output]
237+
235238
pred_mean = np.mean(pred, 0)
236239

237240
return (pred - pred_mean).T
@@ -245,6 +248,7 @@ def random_forest_error(
245248
calibrate=True,
246249
memory_constrained=False,
247250
memory_limit=None,
251+
y_output=None
248252
):
249253
"""
250254
Calculate error bars from scikit-learn RandomForest estimators.
@@ -287,6 +291,11 @@ def random_forest_error(
287291
An upper bound for how much memory the itermediate matrices will take
288292
up in Megabytes. This must be provided if memory_constrained=True.
289293
294+
y_output: int, mandatory only for MultiOutput regressor.
295+
In case of MultiOutput regressor, indicate the index of the target to
296+
analyse. The program will return the IJ variance related to that target
297+
only.
298+
290299
Returns
291300
-------
292301
An array with the unbiased sampling variance (V_IJ_unbiased)
@@ -306,10 +315,15 @@ def random_forest_error(
306315
Random Forests: The Jackknife and the Infinitesimal Jackknife", Journal
307316
of Machine Learning Research vol. 15, pp. 1625-1651, 2014.
308317
"""
318+
319+
if forest.n_outputs_ > 1 and y_output == None:
320+
e_s = "MultiOutput regressor: specify the index of the target to analyse (y_output)"
321+
raise ValueError(e_s)
322+
309323
if inbag is None:
310324
inbag = calc_inbag(X_train_shape[0], forest)
311325

312-
pred_centered = _centered_prediction_forest(forest, X_test)
326+
pred_centered = _centered_prediction_forest(forest, X_test, y_output)
313327
n_trees = forest.n_estimators
314328
V_IJ = _core_computation(
315329
X_train_shape, X_test, inbag, pred_centered, n_trees, memory_constrained, memory_limit
@@ -349,6 +363,7 @@ def random_forest_error(
349363
calibrate=False,
350364
memory_constrained=memory_constrained,
351365
memory_limit=memory_limit,
366+
y_output=y_output
352367
)
353368
# Use this second set of variance estimates
354369
# to estimate scale of Monte Carlo noise

0 commit comments

Comments
 (0)