Skip to content

Commit 17152c1

Browse files
authored
Merge pull request #111 from DasCapschen/pass_only_shape
pass only shape for X_train
2 parents 022f6d4 + d91f0bb commit 17152c1

File tree

5 files changed

+20
-21
lines changed

5 files changed

+20
-21
lines changed

examples/plot_mpg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
plt.show()
5252

5353
# Calculate the variance
54-
mpg_V_IJ_unbiased = fci.random_forest_error(mpg_forest, mpg_X_train,
54+
mpg_V_IJ_unbiased = fci.random_forest_error(mpg_forest, mpg_X_train.shape,
5555
mpg_X_test)
5656

5757
# Plot error bars for predicted MPG using unbiased variance

examples/plot_mpg_svr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
plt.show()
5353

5454
# Calculate the variance
55-
mpg_V_IJ_unbiased = fci.random_forest_error(mpg_bagger, mpg_X_train, mpg_X_test)
55+
mpg_V_IJ_unbiased = fci.random_forest_error(mpg_bagger, mpg_X_train.shape, mpg_X_test)
5656

5757
# Plot error bars for predicted MPG using unbiased variance
5858
plt.errorbar(mpg_y_test, mpg_y_hat, yerr=np.sqrt(mpg_V_IJ_unbiased), fmt="o")

examples/plot_spam.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
plt.legend()
4545

4646
# Calculate the variance
47-
spam_V_IJ_unbiased = fci.random_forest_error(spam_RFC, spam_X_train,
47+
spam_V_IJ_unbiased = fci.random_forest_error(spam_RFC, spam_X_train.shape,
4848
spam_X_test)
4949

5050
# Plot forest prediction for emails and standard deviation for estimates

forestci/forestci.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def calc_inbag(n_samples, forest):
9090

9191

9292
def _core_computation(
93-
X_train,
93+
X_train_shape,
9494
X_test,
9595
inbag,
9696
pred_centered,
@@ -104,8 +104,8 @@ def _core_computation(
104104
105105
Parameters
106106
----------
107-
X_train : ndarray
108-
An array with shape (n_train_sample, n_features).
107+
X_train_shape : tuple (int, int)
108+
Shape (n_train_sample, n_features).
109109
110110
X_test : ndarray
111111
An array with shape (n_test_sample, n_features).
@@ -140,10 +140,10 @@ def _core_computation(
140140
raise ValueError("If memory_constrained=True, must provide", "memory_limit.")
141141

142142
# Assumes double precision float
143-
chunk_size = int((memory_limit * 1e6) / (8.0 * X_train.shape[0]))
143+
chunk_size = int((memory_limit * 1e6) / (8.0 * X_train_shape[0]))
144144

145145
if chunk_size == 0:
146-
min_limit = 8.0 * X_train.shape[0] / 1e6
146+
min_limit = 8.0 * X_train_shape[0] / 1e6
147147
raise ValueError(
148148
"memory_limit provided is too small."
149149
+ "For these dimensions, memory_limit must "
@@ -238,7 +238,7 @@ def _centered_prediction_forest(forest, X_test):
238238

239239
def random_forest_error(
240240
forest,
241-
X_train,
241+
X_train_shape,
242242
X_test,
243243
inbag=None,
244244
calibrate=True,
@@ -256,9 +256,8 @@ def random_forest_error(
256256
forest : RandomForest
257257
Regressor or Classifier object.
258258
259-
X_train : ndarray
260-
An array with shape (n_train_sample, n_features). The design matrix for
261-
training data.
259+
X_train_shape : tuple (int, int)
260+
Shape (n_train_sample, n_features) of the design matrix for training data.
262261
263262
X_test : ndarray
264263
An array with shape (n_test_sample, n_features). The design matrix
@@ -307,12 +306,12 @@ def random_forest_error(
307306
of Machine Learning Research vol. 15, pp. 1625-1651, 2014.
308307
"""
309308
if inbag is None:
310-
inbag = calc_inbag(X_train.shape[0], forest)
309+
inbag = calc_inbag(X_train_shape[0], forest)
311310

312311
pred_centered = _centered_prediction_forest(forest, X_test)
313312
n_trees = forest.n_estimators
314313
V_IJ = _core_computation(
315-
X_train, X_test, inbag, pred_centered, n_trees, memory_constrained, memory_limit
314+
X_train_shape, X_test, inbag, pred_centered, n_trees, memory_constrained, memory_limit
316315
)
317316
V_IJ_unbiased = _bias_correction(V_IJ, inbag, pred_centered, n_trees)
318317

@@ -344,7 +343,7 @@ def random_forest_error(
344343

345344
results_ss = random_forest_error(
346345
new_forest,
347-
X_train,
346+
X_train_shape,
348347
X_test,
349348
calibrate=False,
350349
memory_constrained=memory_constrained,

forestci/tests/test_forestci.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_random_forest_error():
2626
for ib in [inbag, None]:
2727
for calibrate in [True, False]:
2828
V_IJ_unbiased = fci.random_forest_error(
29-
forest, X_train, X_test, inbag=ib, calibrate=calibrate
29+
forest, X_train.shape, X_test, inbag=ib, calibrate=calibrate
3030
)
3131
npt.assert_equal(V_IJ_unbiased.shape[0], y_test.shape[0])
3232

@@ -60,7 +60,7 @@ def test_bagging_svr_error():
6060
for ib in [inbag, None]:
6161
for calibrate in [True, False]:
6262
V_IJ_unbiased = fci.random_forest_error(
63-
bagger, X_train, X_test, inbag=ib, calibrate=calibrate
63+
bagger, X_train.shape, X_test, inbag=ib, calibrate=calibrate
6464
)
6565
npt.assert_equal(V_IJ_unbiased.shape[0], y_test.shape[0])
6666

@@ -78,7 +78,7 @@ def test_core_computation():
7878
n_trees = 4
7979

8080
our_vij = fci._core_computation(
81-
X_train_ex, X_test_ex, inbag_ex, pred_centered_ex, n_trees
81+
X_train_ex.shape, X_test_ex, inbag_ex, pred_centered_ex, n_trees
8282
)
8383

8484
r_vij = np.concatenate([np.array([112.5, 387.5]) for _ in range(1000)])
@@ -87,7 +87,7 @@ def test_core_computation():
8787

8888
for mc, ml in zip([True, False], [0.01, None]):
8989
our_vij = fci._core_computation(
90-
X_train_ex,
90+
X_train_ex.shape,
9191
X_test_ex,
9292
inbag_ex,
9393
pred_centered_ex,
@@ -113,7 +113,7 @@ def test_bias_correction():
113113
n_trees = 4
114114

115115
our_vij = fci._core_computation(
116-
X_train_ex, X_test_ex, inbag_ex, pred_centered_ex, n_trees
116+
X_train_ex.shape, X_test_ex, inbag_ex, pred_centered_ex, n_trees
117117
)
118118
our_vij_unbiased = fci._bias_correction(
119119
our_vij, inbag_ex, pred_centered_ex, n_trees
@@ -139,7 +139,7 @@ def test_with_calibration():
139139
n_trees = 4
140140
forest = RandomForestRegressor(n_estimators=n_trees)
141141
forest.fit(X_train, y_train)
142-
V_IJ_unbiased = fci.random_forest_error(forest, X_train, X_test)
142+
V_IJ_unbiased = fci.random_forest_error(forest, X_train.shape, X_test)
143143
npt.assert_equal(V_IJ_unbiased.shape[0], y_test.shape[0])
144144

145145

0 commit comments

Comments (0)