Skip to content

Commit 2b05a96

Browse files
committed
simplify traspose and reshape
1 parent d170994 commit 2b05a96

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

forestci/forestci.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,10 @@ def _centered_prediction_forest(forest, X_test):
231231
if len(X_test.shape) == 1:
232232
X_test = X_test.reshape(1, -1)
233233

234-
pred = np.array([tree.predict(X_test) for tree in forest]).T
235-
pred_mean = np.mean(pred, 1).reshape(X_test.shape[0], 1)
234+
pred = np.array([tree.predict(X_test) for tree in forest])
235+
pred_mean = np.mean(pred, 0)
236236

237-
return pred - pred_mean
237+
return (pred - pred_mean).T
238238

239239

240240
def random_forest_error(

0 commit comments

Comments
 (0)