diff --git a/examples/plot_pruning.py b/examples/plot_pruning.py index 67f506c..eb011a0 100644 --- a/examples/plot_pruning.py +++ b/examples/plot_pruning.py @@ -27,10 +27,10 @@ from sklearn.datasets import load_iris from sklearn.linear_model import LogisticRegression -iris = load_iris(as_frame=True) +iris = load_iris() baseline_lr = LogisticRegression(max_iter=1000).fit(iris["data"], iris["target"]) -X_train = iris["data"].values[10:120] -y_train = iris["target"].values[10:120] +X_train = iris["data"][10:120] +y_train = iris["target"][10:120] # %% # Random data pruning