@@ -200,7 +200,7 @@ def _bias_correction(V_IJ, inbag, pred_centered, n_trees):
200200 return V_IJ_unbiased
201201
202202
203- def _centered_prediction_forest (forest , X_test ):
203+ def _centered_prediction_forest (forest , X_test , y_output ):
204204 """
205205 Center the tree predictions by the mean prediction (forest)
206206
@@ -232,6 +232,9 @@ def _centered_prediction_forest(forest, X_test):
232232 X_test = X_test .reshape (1 , - 1 )
233233
234234 pred = np .array ([tree .predict (X_test ) for tree in forest ])
235+ if forest .n_outputs_ > 1 :
236+ pred = pred [:,:,y_output ]
237+
235238 pred_mean = np .mean (pred , 0 )
236239
237240 return (pred - pred_mean ).T
@@ -245,6 +248,7 @@ def random_forest_error(
245248 calibrate = True ,
246249 memory_constrained = False ,
247250 memory_limit = None ,
251+ y_output = None
248252):
249253 """
250254 Calculate error bars from scikit-learn RandomForest estimators.
@@ -287,6 +291,11 @@ def random_forest_error(
287291 An upper bound for how much memory the itermediate matrices will take
288292 up in Megabytes. This must be provided if memory_constrained=True.
289293
294+ y_output: int, mandatory only for MultiOutput regressor.
295+ In case of MultiOutput regressor, indicate the index of the target to
296+ analyse. The program will return the IJ variance related to that target
297+ only.
298+
290299 Returns
291300 -------
292301 An array with the unbiased sampling variance (V_IJ_unbiased)
@@ -306,10 +315,15 @@ def random_forest_error(
306315 Random Forests: The Jackknife and the Infinitesimal Jackknife", Journal
307316 of Machine Learning Research vol. 15, pp. 1625-1651, 2014.
308317 """
318+
319+ if forest .n_outputs_ > 1 and y_output == None :
320+ e_s = "MultiOutput regressor: specify the index of the target to analyse (y_output)"
321+ raise ValueError (e_s )
322+
309323 if inbag is None :
310324 inbag = calc_inbag (X_train_shape [0 ], forest )
311325
312- pred_centered = _centered_prediction_forest (forest , X_test )
326+ pred_centered = _centered_prediction_forest (forest , X_test , y_output )
313327 n_trees = forest .n_estimators
314328 V_IJ = _core_computation (
315329 X_train_shape , X_test , inbag , pred_centered , n_trees , memory_constrained , memory_limit
@@ -349,6 +363,7 @@ def random_forest_error(
349363 calibrate = False ,
350364 memory_constrained = memory_constrained ,
351365 memory_limit = memory_limit ,
366+ y_output = y_output
352367 )
353368 # Use this second set of variance estimates
354369 # to estimate scale of Monte Carlo noise
0 commit comments