Skip to content

Commit b483e44

Browse files
author
Dominik Waurenschk
committed
pass only shape for X_train
1 parent 022f6d4 commit b483e44

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

forestci/forestci.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def calc_inbag(n_samples, forest):
9090

9191

9292
def _core_computation(
93-
X_train,
93+
X_train_shape,
9494
X_test,
9595
inbag,
9696
pred_centered,
@@ -104,8 +104,8 @@ def _core_computation(
104104
105105
Parameters
106106
----------
107-
X_train : ndarray
108-
An array with shape (n_train_sample, n_features).
107+
X_train_shape : tuple (int, int)
108+
Shape (n_train_sample, n_features).
109109
110110
X_test : ndarray
111111
An array with shape (n_test_sample, n_features).
@@ -140,10 +140,10 @@ def _core_computation(
140140
raise ValueError("If memory_constrained=True, must provide", "memory_limit.")
141141

142142
# Assumes double precision float
143-
chunk_size = int((memory_limit * 1e6) / (8.0 * X_train.shape[0]))
143+
chunk_size = int((memory_limit * 1e6) / (8.0 * X_train_shape[0]))
144144

145145
if chunk_size == 0:
146-
min_limit = 8.0 * X_train.shape[0] / 1e6
146+
min_limit = 8.0 * X_train_shape[0] / 1e6
147147
raise ValueError(
148148
"memory_limit provided is too small."
149149
+ "For these dimensions, memory_limit must "
@@ -238,7 +238,7 @@ def _centered_prediction_forest(forest, X_test):
238238

239239
def random_forest_error(
240240
forest,
241-
X_train,
241+
X_train_shape,
242242
X_test,
243243
inbag=None,
244244
calibrate=True,
@@ -256,9 +256,8 @@ def random_forest_error(
256256
forest : RandomForest
257257
Regressor or Classifier object.
258258
259-
X_train : ndarray
260-
An array with shape (n_train_sample, n_features). The design matrix for
261-
training data.
259+
X_train_shape : tuple (int, int)
260+
Shape (n_train_sample, n_features) of the design matrix for training data.
262261
263262
X_test : ndarray
264263
An array with shape (n_test_sample, n_features). The design matrix
@@ -307,12 +306,12 @@ def random_forest_error(
307306
of Machine Learning Research vol. 15, pp. 1625-1651, 2014.
308307
"""
309308
if inbag is None:
310-
inbag = calc_inbag(X_train.shape[0], forest)
309+
inbag = calc_inbag(X_train_shape[0], forest)
311310

312311
pred_centered = _centered_prediction_forest(forest, X_test)
313312
n_trees = forest.n_estimators
314313
V_IJ = _core_computation(
315-
X_train, X_test, inbag, pred_centered, n_trees, memory_constrained, memory_limit
314+
X_train_shape, X_test, inbag, pred_centered, n_trees, memory_constrained, memory_limit
316315
)
317316
V_IJ_unbiased = _bias_correction(V_IJ, inbag, pred_centered, n_trees)
318317

@@ -344,7 +343,7 @@ def random_forest_error(
344343

345344
results_ss = random_forest_error(
346345
new_forest,
347-
X_train,
346+
X_train_shape,
348347
X_test,
349348
calibrate=False,
350349
memory_constrained=memory_constrained,

0 commit comments

Comments
 (0)