@@ -200,7 +200,7 @@ def _bias_correction(V_IJ, inbag, pred_centered, n_trees):
     return V_IJ_unbiased


-def _centered_prediction_forest(forest, X_test):
+def _centered_prediction_forest(forest, X_test, y_output=None):
     """
     Center the tree predictions by the mean prediction (forest)

@@ -224,16 +224,20 @@ def _centered_prediction_forest(forest, X_test):
     mean prediction (i.e. the prediction of the forest)

     """
-    # reformatting required for single sample arrays
-    # caution: assumption that number of features always > 1
+    # If the user provided an (n_features,)-shaped array for a single sample,
+    # reshape it to (1, n_features).
+    # NOTE: samples with a single feature must be provided with shape
+    # (n_samples, 1) or they will be misinterpreted!
     if len(X_test.shape) == 1:
-        # reshape according to the reshaping annotation in scikit-learn
         X_test = X_test.reshape(1, -1)

-    pred = np.array([tree.predict(X_test) for tree in forest]).T
-    pred_mean = np.mean(pred, 1).reshape(X_test.shape[0], 1)
+    pred = np.array([tree.predict(X_test) for tree in forest])
+    if 'n_outputs_' in dir(forest) and forest.n_outputs_ > 1:
+        pred = pred[:, :, y_output]

-    return pred - pred_mean
+    pred_mean = np.mean(pred, 0)
+
+    return (pred - pred_mean).T


 def random_forest_error(
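
As an illustration of what the new multi-output branch above does, the following shape walk-through is a minimal sketch with made-up dimensions (not part of the patch): per-tree predictions of a multi-output forest arrive as (n_trees, n_samples, n_outputs), one target is selected via y_output, and the result is centered by the per-sample forest mean and transposed to the (n_samples, n_trees) layout that the previous single-output code already returned.

import numpy as np

n_trees, n_samples, n_outputs = 3, 5, 2
rng = np.random.RandomState(0)

# Stand-in for np.array([tree.predict(X_test) for tree in forest])
# when forest.n_outputs_ > 1.
pred = rng.rand(n_trees, n_samples, n_outputs)

y_output = 1
pred = pred[:, :, y_output]            # (n_trees, n_samples): keep one target
pred_mean = np.mean(pred, 0)           # (n_samples,): the forest prediction per sample
pred_centered = (pred - pred_mean).T   # (n_samples, n_trees), same layout as before

assert pred_centered.shape == (n_samples, n_trees)
assert np.allclose(pred_centered.sum(axis=1), 0.0)  # each sample's tree spread centers at zero
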
@@ -244,6 +248,7 @@ def random_forest_error(
     calibrate=True,
     memory_constrained=False,
     memory_limit=None,
+    y_output=None,
 ):
     """
     Calculate error bars from scikit-learn RandomForest estimators.
@@ -286,6 +291,11 @@ def random_forest_error(
         An upper bound for how much memory the intermediate matrices will take
         up in Megabytes. This must be provided if memory_constrained=True.

+    y_output: int, required only for multi-output regressors.
+        For a multi-output regressor, the index of the target to analyse.
+        The returned infinitesimal jackknife (IJ) variance refers to that
+        target only.
+
     Returns
     -------
     An array with the unbiased sampling variance (V_IJ_unbiased)
@@ -305,10 +315,15 @@ def random_forest_error(
     Random Forests: The Jackknife and the Infinitesimal Jackknife", Journal
     of Machine Learning Research vol. 15, pp. 1625-1651, 2014.
     """
+
+    if 'n_outputs_' in dir(forest) and forest.n_outputs_ > 1 and y_output is None:
+        e_s = "MultiOutput regressor: specify the index of the target to analyse (y_output)"
+        raise ValueError(e_s)
+
     if inbag is None:
         inbag = calc_inbag(X_train_shape[0], forest)

-    pred_centered = _centered_prediction_forest(forest, X_test)
+    pred_centered = _centered_prediction_forest(forest, X_test, y_output)
     n_trees = forest.n_estimators
     V_IJ = _core_computation(
         X_train_shape, X_test, inbag, pred_centered, n_trees, memory_constrained, memory_limit
@@ -348,6 +363,7 @@ def random_forest_error(
         calibrate=False,
         memory_constrained=memory_constrained,
         memory_limit=memory_limit,
+        y_output=y_output,
     )
     # Use this second set of variance estimates
     # to estimate scale of Monte Carlo noise
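
For completeness, a hedged end-to-end usage sketch of the new y_output argument. The dataset, model settings, and the `import forestci as fci` path are assumptions for illustration; only the random_forest_error signature and the y_output behaviour come from this patch.

import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
import forestci as fci  # assumption: the patched function is exposed at the package level

rng = np.random.RandomState(0)
X = rng.rand(500, 4)
# Two targets, so forest.n_outputs_ == 2 after fitting.
y = np.column_stack([X.sum(axis=1), X[:, 0] - X[:, 1]]) + 0.1 * rng.randn(500, 2)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
forest = RandomForestRegressor(n_estimators=200, random_state=0).fit(X_train, y_train)

# X_train.shape matches the X_train_shape parameter used in this version.
# Variance estimates for target 0 only; omitting y_output on a multi-output
# forest raises the ValueError introduced above.
V_IJ_unbiased = fci.random_forest_error(forest, X_train.shape, X_test, y_output=0)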