
Commit 9450a91

V0.3.0 (#61)

* update docs for 0.3.0
* update README for 0.3.0
* add plot_calibration_curve

1 parent db3af7d · commit 9450a91

16 files changed (+375, −75 lines)

README.md

Lines changed: 21 additions & 23 deletions
@@ -6,11 +6,11 @@
 [![PyPI](https://img.shields.io/pypi/pyversions/scikit-plot.svg)]()
 [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.293191.svg)](https://doi.org/10.5281/zenodo.293191)
 
-### Scikit-learn with plotting.
+### Single line functions for detailed visualizations
 
 ### The quickest and easiest way to go from analysis...
 
-![roc_curves](examples/readme_collage.jpg)
+![roc_curves](docs/_static/readme_collage.jpg)
 
 ### ...to this.
 
@@ -24,58 +24,56 @@ That said, there are a number of visualizations that frequently pop up in machin
 
 Say we use Naive Bayes in multi-class classification and decide we want to visualize the results of a common classification metric, the Area under the Receiver Operating Characteristic curve. Since the ROC is only valid in binary classification, we want to show the respective ROC of each class if it were the positive class. As an added bonus, let's show the micro-averaged and macro-averaged curve in the plot as well.
 
-Using scikit-plot with the sample digits dataset from scikit-learn.
+Let's use scikit-plot with the sample digits dataset from scikit-learn.
 
 ```python
-from sklearn.datasets import load_digits as load_data
+# The usual train-test split mumbo-jumbo
+from sklearn.datasets import load_digits
+from sklearn.model_selection import train_test_split
 from sklearn.naive_bayes import GaussianNB
 
-# This is all that's needed for scikit-plot
-import matplotlib.pyplot as plt
-from scikitplot import classifier_factory
-
-X, y = load_data(return_X_y=True)
+X, y = load_digits(return_X_y=True)
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
 nb = GaussianNB()
-classifier_factory(nb)
-nb.plot_roc_curve(X, y, random_state=1)
+nb.fit(X_train, y_train)
+predicted_probas = nb.predict_proba(X_test)
+
+# The magic happens here
+import matplotlib.pyplot as plt
+import scikitplot as skplt
+skplt.metrics.plot_roc_curve(y_test, predicted_probas)
 plt.show()
 ```
 ![roc_curves](examples/roc_curves.png)
 
 Pretty.
 
-So what happened here? First, we created a regular Naive Bayes classifier instance from scikit-learn and assigned it to `nb`. We then passed `nb` to `classifier_factory`. Then, like magic, we call `nb`'s *instance method* `plot_roc_curve` and pass it a features array and corresponding label array. Finally, we call `plt.show()` to display the corresponding plot.
-
-Wait, what? The scikit-learn `GaussianNB` classifier doesn't have a `plot_roc_curve` method. How does this not throw an error? Well, `classifier_factory` is a function that modifies an __instance__ of a scikit-learn classifier. When we passed `nb` to `classifier_factory`, it __appended__ new plotting methods to the instance, one of which was `plot_roc_curve`, while leaving everything else alone.
-
-This means that our classifier instance `nb` will behave the same way as before, with all its original variables and methods intact. In fact, if you take any of your existing scripts, pass your classifier instances to `classifier_factory` at the top and run them, you'll likely never notice a difference!
-
-Classifiers aren't the only Scikit-learn objects. Scikit-plot offers a `clusterer_factory` function for generating common clustering plots. Visit the [docs](http://scikit-plot.readthedocs.io/en/latest/) for a complete list of what you can accomplish.
+And... that's it. Captured in that small example is the entire philosophy of Scikit-plot: **single line functions for detailed visualization**. You simply browse the plots available in the documentation, and call the function with the necessary arguments. Scikit-plot tries to stay out of your way as much as possible. No unnecessary bells and whistles. And when you *do* need the bells and whistles, each function offers a myriad of parameters for customizing various elements in your plots.
 
 Finally, compare and [view the non-scikit-plot way of plotting the multi-class ROC curve](http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html). Which one would you rather do?
 
 ## Maximum flexibility. Compatibility with non-scikit-learn objects.
 
-Although convenient, the Factory API may feel a little restrictive for more advanced users and users of external libraries. Thus, to offer more flexibility over your plotting, Scikit-plot also exposes a Functions API that, well, exposes functions.
+Although Scikit-plot is loosely based around the scikit-learn interface, you don't actually need Scikit-learn objects to use the available functions. As long as you provide the functions what they're asking for, they'll happily draw the plots for you.
 
 Here's a quick example to generate the precision-recall curves of a Keras classifier on a sample dataset.
 
 ```python
 # Import what's needed for the Functions API
 import matplotlib.pyplot as plt
-import scikitplot.plotters as skplt
+import scikitplot as skplt
 
 # This is a Keras classifier. We'll generate probabilities on the test set.
 keras_clf.fit(X_train, y_train, batch_size=64, nb_epoch=10, verbose=2)
 probas = keras_clf.predict_proba(X_test, batch_size=64)
 
 # Now plot.
-skplt.plot_precision_recall_curve(y_test, probas)
+skplt.metrics.plot_precision_recall_curve(y_test, probas)
 plt.show()
 ```
 ![p_r_curves](examples/p_r_curves.png)
 
-You can see clearly here that `skplt.plot_precision_recall_curve` needs only the ground truth y-values and the predicted probabilities to generate the plot. This lets you use *anything* you want as the classifier, from Keras NNs to NLTK Naive Bayes to that groundbreaking classifier algorithm you just wrote.
+You can see clearly here that `skplt.metrics.plot_precision_recall_curve` needs only the ground truth y-values and the predicted probabilities to generate the plot. This lets you use *anything* you want as the classifier, from Keras NNs to NLTK Naive Bayes to that groundbreaking classifier algorithm you just wrote.
 
 The possibilities are endless.
 
@@ -88,7 +86,7 @@ Then just run:
 pip install scikit-plot
 ```
 
-Or if you want, clone this repo and run
+Or if you want the latest development version, clone this repo and run
 ```bash
 python setup.py install
 ```

docs/Quickstart.rst

Lines changed: 19 additions & 32 deletions
@@ -32,7 +32,7 @@ Before we begin plotting, we'll need to import the following for Scikit-plot::
 
    >>> import matplotlib.pyplot as plt
 
-:class:`matplotlib.pyplot` is used by Matplotlib to make plotting work like it does in MATLAB and deals with things like axes, figures, and subplots. But don't worry. Unless you're an advanced user, you won't need to understand any of that while using Scikit-plot. All you need to remember is that we use the ``matplotlib.pyplot.show()`` function to show any plots generated by Scikit-plot.
+:mod:`matplotlib.pyplot` is used by Matplotlib to make plotting work like it does in MATLAB and deals with things like axes, figures, and subplots. But don't worry. Unless you're an advanced user, you won't need to understand any of that while using Scikit-plot. All you need to remember is that we use the :func:`matplotlib.pyplot.show` function to show any plots generated by Scikit-plot.
 
 Let's begin by generating our sample digits dataset::
 
@@ -46,24 +46,17 @@ We'll proceed by creating an instance of a RandomForestClassifier object from Sc
 >>> from sklearn.ensemble import RandomForestClassifier
 >>> random_forest_clf = RandomForestClassifier(n_estimators=5, max_depth=5, random_state=1)
 
-The magic happens in the next two lines::
+Let's use :func:`sklearn.model_selection.cross_val_predict` to generate predicted labels on our dataset::
 
->>> from scikitplot import classifier_factory
->>> classifier_factory(random_forest_clf)
-RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
-            max_depth=5, max_features='auto', max_leaf_nodes=None,
-            min_impurity_split=1e-07, min_samples_leaf=1,
-            min_samples_split=2, min_weight_fraction_leaf=0.0,
-            n_estimators=5, n_jobs=1, oob_score=False, random_state=1,
-            verbose=0, warm_start=False)
+>>> from sklearn.model_selection import cross_val_predict
+>>> predictions = cross_val_predict(random_forest_clf, X, y)
 
-In detail, here's what happened. :func:`~scikitplot.classifier_factory` is a function that modifies an instance of a scikit-learn classifier. When we passed ``random_forest_clf`` to :func:`~scikitplot.classifier_factory`, it **appended** new plotting methods to the instance, while leaving everything else alone. The original variables and methods of ``random_forest_clf`` are kept intact. In fact, if you take any of your existing scripts, pass your classifier instances to :func:`~scikitplot.classifier_factory` at the top and run them, you'll likely never notice a difference! (If something does break, though, we'd appreciate it if you open an issue at Scikit-plot's `Github repository <https://github.com/reiinakano/scikit-plot>`_.)
+For those not familiar with what :func:`cross_val_predict` does, it generates cross-validated estimates for each sample point in our dataset. Comparing the cross-validated estimates with the true labels, we'll be able to get evaluation metrics such as accuracy, precision, recall, and in our case, the confusion matrix.
 
-Among the methods added to our classifier instance is the :func:`~scikitplot.classifiers.plot_confusion_matrix` method, used to generate a colored heatmap of the classifier's confusion matrix as evaluated on a dataset.
+To plot and show our confusion matrix, we'll use the function :func:`~scikitplot.metrics.plot_confusion_matrix`, passing it both the true labels and predicted labels. We'll also set the optional argument ``normalize=True`` so the values displayed in our confusion matrix plot will be from the range [0, 1]. Finally, to show our plot, we'll call ``plt.show()``.
 
-To plot and show how well our classifier does on the sample dataset, we'll run ``random_forest_clf``'s new instance method :func:`~scikitplot.classifiers.plot_confusion_matrix`, passing it the features and labels of our sample dataset. We'll also pass ``normalize=True`` to :func:`~scikitplot.classifiers.plot_confusion_matrix` so the values displayed in our confusion matrix plot will be from the range [0, 1]. Finally, to show our plot, we'll call ``plt.show()``.
-
->>> random_forest_clf.plot_confusion_matrix(X, y, normalize=True)
+>>> import scikitplot as skplt
+>>> skplt.metrics.plot_confusion_matrix(y, predictions, normalize=True)
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
 >>> plt.show()
 
@@ -73,39 +66,33 @@ To plot and show how well our classifier does on the sample dataset, we'll run `
 
 And that's it! A quick glance of our confusion matrix shows that our classifier isn't doing so well with identifying the digits 1, 8, and 9. Hmm. Perhaps a bit more tweaking of our Random Forest's hyperparameters is in order.
 
-.. admonition:: Note
-
-   The more observant of you will notice that we didn't train our classifier at all. Exactly how was the confusion matrix generated? Well, :func:`~scikitplot.classifiers.plot_confusion_matrix` provides an optional parameter ``do_cv``, set to **True** by default, that determines whether or not the classifier will use cross-validation to generate the confusion matrix. If **True**, the predictions generated by each iteration in the cross-validation are aggregated and used to generate the confusion matrix.
-
-   If you do not wish to do cross-validation e.g. you have separate training and testing datasets, simply set ``do_cv`` to **False** and make sure the classifier is already trained prior to calling :func:`~scikitplot.classifiers.plot_confusion_matrix`. In this case, the confusion matrix will be generated on the predictions of the trained classifier on the passed ``X`` and ``y``.
+One more example
+----------------
 
-The Functions API
------------------
-
-Although convenient, the Factory API may feel a little restrictive for more advanced users and users of external libraries. Thus, to offer more flexibility over your plotting, Scikit-plot also exposes a Functions API that, well, exposes functions.
-
-The nature of the Functions API offers compatibility with non-scikit-learn objects.
+Finally, let's show an example wherein we *don't* use Scikit-learn.
 
 Here's a quick example to generate the precision-recall curves of a Keras classifier on a sample dataset.
 
 >>> # Import what's needed for the Functions API
 >>> import matplotlib.pyplot as plt
->>> import scikitplot.plotters as skplt
+>>> import scikitplot as skplt
 >>> # This is a Keras classifier. We'll generate probabilities on the test set.
 >>> keras_clf.fit(X_train, y_train, batch_size=64, nb_epoch=10, verbose=2)
 >>> probas = keras_clf.predict_proba(X_test, batch_size=64)
 >>> # Now plot.
->>> skplt.plot_precision_recall_curve(y_test, probas)
+>>> skplt.metrics.plot_precision_recall_curve(y_test, probas)
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
 >>> plt.show()
 
 .. image:: _static/quickstart_plot_precision_recall_curve.png
    :align: center
    :alt: Precision Recall Curves
 
-And again, that's it! You'll notice that in this plot, all we needed to do was pass the ground truth labels and predicted probabilities to :func:`~scikitplot.plotters.plot_precision_recall_curve` to generate the precision-recall curves. This means you can use literally any classifier you want to generate the precision-recall curves, from Keras classifiers to NLTK Naive Bayes to XGBoost, as long as you pass in the predicted probabilities in the correct format.
+And again, that's it! As in the example above, all we needed to do was pass the ground truth labels and predicted probabilities to :func:`~scikitplot.metrics.plot_precision_recall_curve` to generate the precision-recall curves. This means you can use literally any classifier you want to generate the precision-recall curves, from Keras classifiers to NLTK Naive Bayes to XGBoost, as long as you pass in the predicted probabilities in the correct format.
+
+Now what?
+---------
 
-More Plots
-----------
+The recommended way to start using Scikit-plot is to just go through the documentation for the various modules and choose which plots you think would be useful for your work.
 
-Want to know the other plots you can generate using Scikit-plot? Visit the :ref:`factoryapidocs` or the :ref:`functionsapidocs`.
+Happy plotting!
Two binary image files changed (59.4 KB and 21.2 KB; previews not shown).

docs/cluster.rst

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+.. apidocs file containing the API Documentation
+.. _clusterdocs:
+
+Clusterer Module (API Reference)
+================================
+
+.. automodule:: scikitplot.cluster
+   :members: plot_elbow_curve

docs/decomposition.rst

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+.. apidocs file containing the API Documentation
+.. _decompositiondocs:
+
+Decomposition Module (API Reference)
+====================================
+
+.. automodule:: scikitplot.decomposition
+   :members: plot_pca_component_variance, plot_pca_2d_projection

docs/estimators.rst

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+.. apidocs file containing the API Documentation
+.. _estimatorssdocs:
+
+Estimators Module (API Reference)
+=================================
+
+.. automodule:: scikitplot.estimators
+   :members: plot_learning_curve, plot_feature_importances

docs/index.rst

Lines changed: 20 additions & 10 deletions
@@ -6,18 +6,28 @@
 Welcome to Scikit-plot's documentation!
 =======================================
 
+The quickest and easiest way to go from analysis...
+---------------------------------------------------
+
+.. image:: _static/readme_collage.jpg
+   :align: center
+   :alt: All plots
+
+...to this.
+-----------
+
+Scikit-plot is the result of an unartistic data scientist's dreadful realization that *visualization is one of the most crucial components in the data science process, not just a mere afterthought*.
+
+Gaining insights is simply a lot easier when you're looking at a colored heatmap of a confusion matrix complete with class labels rather than a single-line dump of numbers enclosed in brackets. Besides, if you ever need to present your results to someone (virtually any time anybody hires you to do data science), you show them visualizations, not a bunch of numbers in Excel.
+
+That said, there are a number of visualizations that frequently pop up in machine learning. Scikit-plot is a humble attempt to provide aesthetically-challenged programmers (such as myself) the opportunity to generate quick and beautiful graphs and plots with as little boilerplate as possible.
+
 .. toctree::
    :maxdepth: 2
    :name: mastertoc
 
    First Steps with Scikit-plot <Quickstart>
-   Factory API Reference <apidocs>
-   Functions API Reference <functionsapidocs>
-
-
-Indices and tables
-==================
-
-* :ref:`genindex`
-* :ref:`modindex`
-* :ref:`search`
+   Metrics Module <metrics>
+   Estimators Module <estimators>
+   Clusterer Module <cluster>
+   Decomposition Module <decomposition>

docs/metrics.rst

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+.. apidocs file containing the API Documentation
+.. _metricsdocs:
+
+Metrics Module (API Reference)
+==============================
+
+.. automodule:: scikitplot.metrics
+   :members: plot_confusion_matrix, plot_roc_curve, plot_ks_statistic, plot_precision_recall_curve, plot_silhouette, plot_calibration_curve
