- 
                Notifications
    
You must be signed in to change notification settings  - Fork 49
 
Open
Description
I was trying to run the Plotting Classification Forest Error Bars. While the first plot is correctly produced, there is TypeError before the scatter plot can be rendered.
INPUT: (Copy-pasted the example code from the documentation)
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
import forestci as fci
from sklearn.datasets import make_classification
spam_X, spam_y = make_classification(5000)
# split the datainto training and test set
spam_X_train, spam_X_test, spam_y_train, spam_y_test = train_test_split(
                                                       spam_X, spam_y,
                                                       test_size=0.2)
# create RandomForestClassifier
n_trees = 500
spam_RFC = RandomForestClassifier(max_features=5, n_estimators=n_trees,
                                  random_state=42)
spam_RFC.fit(spam_X_train, spam_y_train)
spam_y_hat = spam_RFC.predict_proba(spam_X_test)
idx_spam = np.where(spam_y_test == 1)[0]
idx_ham = np.where(spam_y_test == 0)[0]
# Histogram predictions without error bars:
fig, ax = plt.subplots(1)
ax.hist(spam_y_hat[idx_spam, 1], histtype='step', label='spam')
ax.hist(spam_y_hat[idx_ham, 1], histtype='step', label='not spam')
ax.set_xlabel('Prediction (spam probability)')
ax.set_ylabel('Number of observations')
plt.legend()
# Calculate the variance
spam_V_IJ_unbiased = fci.random_forest_error(spam_RFC, spam_X_train,
                                             spam_X_test)
# Plot forest prediction for emails and standard deviation for estimates
# Blue points are spam emails; Green points are non-spam emails
fig, ax = plt.subplots(1)
ax.scatter(spam_y_hat[idx_spam, 1],
           np.sqrt(spam_V_IJ_unbiased[idx_spam]),
           label='spam')
ax.scatter(spam_y_hat[idx_ham, 1],
           np.sqrt(spam_V_IJ_unbiased[idx_ham]),
           label='not spam')
ax.set_xlabel('Prediction (spam probability)')
ax.set_ylabel('Standard deviation')
plt.legend()
plt.show()OUTPUT:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[23], [line 34](vscode-notebook-cell:?execution_count=23&line=34)
     [31](vscode-notebook-cell:?execution_count=23&line=31) plt.legend()
     [33](vscode-notebook-cell:?execution_count=23&line=33) # Calculate the variance
---> [34](vscode-notebook-cell:?execution_count=23&line=34) spam_V_IJ_unbiased = fci.random_forest_error(spam_RFC, spam_X_train,
     [35](vscode-notebook-cell:?execution_count=23&line=35)                                              spam_X_test)
     [37](vscode-notebook-cell:?execution_count=23&line=37) # Plot forest prediction for emails and standard deviation for estimates
     [38](vscode-notebook-cell:?execution_count=23&line=38) # Blue points are spam emails; Green points are non-spam emails
     [39](vscode-notebook-cell:?execution_count=23&line=39) fig, ax = plt.subplots(1)
File ~/Documents/VSCodeProjects/CS328-Final-Project/.venv/lib/python3.12/site-packages/forestci/forestci.py:324, in random_forest_error(forest, X_train_shape, X_test, inbag, calibrate, memory_constrained, memory_limit, y_output)
    [321](https://file+.vscode-resource.vscode-cdn.net/Users/logan/Documents/VSCodeProjects/CS328-Final-Project/~/Documents/VSCodeProjects/CS328-Final-Project/.venv/lib/python3.12/site-packages/forestci/forestci.py:321)     raise ValueError(e_s)
    [323](https://file+.vscode-resource.vscode-cdn.net/Users/logan/Documents/VSCodeProjects/CS328-Final-Project/~/Documents/VSCodeProjects/CS328-Final-Project/.venv/lib/python3.12/site-packages/forestci/forestci.py:323) if inbag is None:
--> [324](https://file+.vscode-resource.vscode-cdn.net/Users/logan/Documents/VSCodeProjects/CS328-Final-Project/~/Documents/VSCodeProjects/CS328-Final-Project/.venv/lib/python3.12/site-packages/forestci/forestci.py:324)     inbag = calc_inbag(X_train_shape[0], forest)
    [326](https://file+.vscode-resource.vscode-cdn.net/Users/logan/Documents/VSCodeProjects/CS328-Final-Project/~/Documents/VSCodeProjects/CS328-Final-Project/.venv/lib/python3.12/site-packages/forestci/forestci.py:326) pred_centered = _centered_prediction_forest(forest, X_test, y_output)
    [327](https://file+.vscode-resource.vscode-cdn.net/Users/logan/Documents/VSCodeProjects/CS328-Final-Project/~/Documents/VSCodeProjects/CS328-Final-Project/.venv/lib/python3.12/site-packages/forestci/forestci.py:327) n_trees = forest.n_estimators
File ~/Documents/VSCodeProjects/CS328-Final-Project/.venv/lib/python3.12/site-packages/forestci/forestci.py:70, in calc_inbag(n_samples, forest)
     [67](https://file+.vscode-resource.vscode-cdn.net/Users/logan/Documents/VSCodeProjects/CS328-Final-Project/~/Documents/VSCodeProjects/CS328-Final-Project/.venv/lib/python3.12/site-packages/forestci/forestci.py:67)     raise ValueError(e_s)
     [69](https://file+.vscode-resource.vscode-cdn.net/Users/logan/Documents/VSCodeProjects/CS328-Final-Project/~/Documents/VSCodeProjects/CS328-Final-Project/.venv/lib/python3.12/site-packages/forestci/forestci.py:69) n_trees = forest.n_estimators
---> [70](https://file+.vscode-resource.vscode-cdn.net/Users/logan/Documents/VSCodeProjects/CS328-Final-Project/~/Documents/VSCodeProjects/CS328-Final-Project/.venv/lib/python3.12/site-packages/forestci/forestci.py:70) inbag = np.zeros((n_samples, n_trees))
     [71](https://file+.vscode-resource.vscode-cdn.net/Users/logan/Documents/VSCodeProjects/CS328-Final-Project/~/Documents/VSCodeProjects/CS328-Final-Project/.venv/lib/python3.12/site-packages/forestci/forestci.py:71) sample_idx = []
     [72](https://file+.vscode-resource.vscode-cdn.net/Users/logan/Documents/VSCodeProjects/CS328-Final-Project/~/Documents/VSCodeProjects/CS328-Final-Project/.venv/lib/python3.12/site-packages/forestci/forestci.py:72) if isinstance(forest, BaseForest):
TypeError: only integer scalar arrays can be converted to a scalar index
Metadata
Metadata
Assignees
Labels
No labels