
TypeError when running example code (Plotting Classification Forest Error Bars) #119

@FlyingWorkshop

I was trying to run the "Plotting Classification Forest Error Bars" example. The first plot (the histogram) is produced correctly, but a TypeError is raised before the scatter plot can be rendered.

INPUT (example code copy-pasted from the documentation):

import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
import forestci as fci
from sklearn.datasets import make_classification

spam_X, spam_y = make_classification(5000)

# split the data into training and test sets
spam_X_train, spam_X_test, spam_y_train, spam_y_test = train_test_split(
                                                       spam_X, spam_y,
                                                       test_size=0.2)

# create RandomForestClassifier
n_trees = 500
spam_RFC = RandomForestClassifier(max_features=5, n_estimators=n_trees,
                                  random_state=42)
spam_RFC.fit(spam_X_train, spam_y_train)
spam_y_hat = spam_RFC.predict_proba(spam_X_test)

idx_spam = np.where(spam_y_test == 1)[0]
idx_ham = np.where(spam_y_test == 0)[0]

# Histogram predictions without error bars:
fig, ax = plt.subplots(1)
ax.hist(spam_y_hat[idx_spam, 1], histtype='step', label='spam')
ax.hist(spam_y_hat[idx_ham, 1], histtype='step', label='not spam')
ax.set_xlabel('Prediction (spam probability)')
ax.set_ylabel('Number of observations')
plt.legend()

# Calculate the variance
spam_V_IJ_unbiased = fci.random_forest_error(spam_RFC, spam_X_train,
                                             spam_X_test)

# Plot forest prediction for emails and standard deviation for estimates
# Blue points are spam emails; Green points are non-spam emails
fig, ax = plt.subplots(1)
ax.scatter(spam_y_hat[idx_spam, 1],
           np.sqrt(spam_V_IJ_unbiased[idx_spam]),
           label='spam')

ax.scatter(spam_y_hat[idx_ham, 1],
           np.sqrt(spam_V_IJ_unbiased[idx_ham]),
           label='not spam')

ax.set_xlabel('Prediction (spam probability)')
ax.set_ylabel('Standard deviation')
plt.legend()
plt.show()

OUTPUT:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[23], line 34
     31 plt.legend()
     33 # Calculate the variance
---> 34 spam_V_IJ_unbiased = fci.random_forest_error(spam_RFC, spam_X_train,
     35                                              spam_X_test)
     37 # Plot forest prediction for emails and standard deviation for estimates
     38 # Blue points are spam emails; Green points are non-spam emails
     39 fig, ax = plt.subplots(1)

File ~/Documents/VSCodeProjects/CS328-Final-Project/.venv/lib/python3.12/site-packages/forestci/forestci.py:324, in random_forest_error(forest, X_train_shape, X_test, inbag, calibrate, memory_constrained, memory_limit, y_output)
    321     raise ValueError(e_s)
    323 if inbag is None:
--> 324     inbag = calc_inbag(X_train_shape[0], forest)
    326 pred_centered = _centered_prediction_forest(forest, X_test, y_output)
    327 n_trees = forest.n_estimators

File ~/Documents/VSCodeProjects/CS328-Final-Project/.venv/lib/python3.12/site-packages/forestci/forestci.py:70, in calc_inbag(n_samples, forest)
     67     raise ValueError(e_s)
     69 n_trees = forest.n_estimators
---> 70 inbag = np.zeros((n_samples, n_trees))
     71 sample_idx = []
     72 if isinstance(forest, BaseForest):

TypeError: only integer scalar arrays can be converted to a scalar index
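The traceback suggests a likely cause. The signature shown at forestci.py:324 names the second parameter `X_train_shape`, and `calc_inbag` receives `X_train_shape[0]` as `n_samples`. The documentation example passes the whole `spam_X_train` array instead, so `X_train_shape[0]` would be the first training row (an array of 20 feature values) rather than the number of rows, and `np.zeros((n_samples, n_trees))` cannot use it as a dimension. The NumPy-only sketch below reproduces that mechanism without forestci. It is an illustration under assumptions: `zeros_like_calc_inbag` is a hypothetical helper that copies only the failing steps, and the random `X_train` is a stand-in with the same shape as the example's training split.

```python
import numpy as np

# Stand-in for spam_X_train (an assumption, not the real data): make_classification(5000)
# produces 20 features by default, and test_size=0.2 leaves 4000 training rows.
X_train = np.random.default_rng(0).normal(size=(4000, 20))
n_trees = 500


def zeros_like_calc_inbag(X_train_shape, n_trees):
    """Hypothetical helper that copies only the failing steps from the traceback."""
    n_samples = X_train_shape[0]  # random_forest_error passes X_train_shape[0]
    return np.zeros((n_samples, n_trees))  # the line that fails in calc_inbag


# Passing the array: X_train[0] is the first *row* (20 floats), not a row count,
# so NumPy cannot use it as an array dimension and raises TypeError.
try:
    zeros_like_calc_inbag(X_train, n_trees)
except TypeError as err:
    print("TypeError:", err)

# Passing the shape tuple: X_train.shape[0] is the integer number of rows.
inbag = zeros_like_calc_inbag(X_train.shape, n_trees)
print(inbag.shape)  # (4000, 500)
```

If this reading is right, passing the shape rather than the array, i.e. `fci.random_forest_error(spam_RFC, spam_X_train.shape, spam_X_test)`, should avoid the error with the installed version. That assumes the installed forestci has the `X_train_shape` signature shown in the traceback; older releases, which the documentation example appears to target, took the training array itself, so the example may simply need updating.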
