Skip to content

Commit 1d1dc99

Browse files
authored
Add xtick rotation (#38)
* add xtick rotation * fix bug in np.logical_not
1 parent 477871d commit 1d1dc99

File tree

5 files changed

+16
-11
lines changed

5 files changed

+16
-11
lines changed

scikitplot/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import absolute_import, division, print_function, unicode_literals
2-
__version__ = '0.2.6'
2+
__version__ = '0.2.7'
33

44

55
from scikitplot.classifiers import classifier_factory

scikitplot/classifiers.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ def classifier_factory(clf):
5353
return clf
5454

5555

56-
def plot_confusion_matrix(clf, X, y, labels=None, title=None, normalize=False, do_cv=True, cv=None,
57-
shuffle=True, random_state=None, ax=None, figsize=None,
56+
def plot_confusion_matrix(clf, X, y, labels=None, title=None, normalize=False, x_tick_rotation=0,
57+
do_cv=True, cv=None, shuffle=True, random_state=None, ax=None, figsize=None,
5858
title_fontsize="large", text_fontsize="medium"):
5959
"""Generates the confusion matrix for a given classifier and dataset.
6060
@@ -79,6 +79,9 @@ def plot_confusion_matrix(clf, X, y, labels=None, title=None, normalize=False, d
7979
normalize (bool, optional): If True, normalizes the confusion matrix before plotting.
8080
Defaults to False.
8181
82+
x_tick_rotation (int, optional): Rotates x-axis tick labels by the specified angle. This is
83+
useful in cases where there are numerous categories and the labels overlap each other.
84+
8285
do_cv (bool, optional): If True, the classifier is cross-validated on the dataset using the
8386
cross-validation strategy in `cv` to generate the confusion matrix. If False, the
8487
confusion matrix is generated without training or cross-validating the classifier.
@@ -158,7 +161,8 @@ def plot_confusion_matrix(clf, X, y, labels=None, title=None, normalize=False, d
158161
y_true = np.concatenate(trues_list)
159162

160163
ax = plotters.plot_confusion_matrix(y_true=y_true, y_pred=y_pred, labels=labels,
161-
title=title, normalize=normalize, ax=ax, figsize=figsize,
164+
title=title, normalize=normalize,
165+
x_tick_rotation=x_tick_rotation, ax=ax, figsize=figsize,
162166
title_fontsize=title_fontsize, text_fontsize=text_fontsize)
163167

164168
return ax

scikitplot/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def binary_ks_curve(y_true, y_probas):
4242
'{} category/ies'.format(len(lb.classes_)))
4343
idx = encoded_labels == 0
4444
data1 = np.sort(y_probas[idx])
45-
data2 = np.sort(y_probas[-idx])
45+
data2 = np.sort(y_probas[np.logical_not(idx)])
4646

4747
ctr1, ctr2 = 0, 0
4848
thresholds, pct1, pct2 = [], [], []

scikitplot/plotters.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
from sklearn.metrics import silhouette_samples
2323

2424

25-
def plot_confusion_matrix(y_true, y_pred, labels=None, title=None, normalize=False, ax=None,
26-
figsize=None, title_fontsize="large", text_fontsize="medium"):
25+
def plot_confusion_matrix(y_true, y_pred, labels=None, title=None, normalize=False, x_tick_rotation=0,
26+
ax=None, figsize=None, title_fontsize="large", text_fontsize="medium"):
2727
"""Generates confusion matrix plot for a given set of ground truth labels and classifier predictions.
2828
2929
Args:
@@ -44,6 +44,9 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, title=None, normalize=Fal
4444
normalize (bool, optional): If True, normalizes the confusion matrix before plotting.
4545
Defaults to False.
4646
47+
x_tick_rotation (int, optional): Rotates x-axis tick labels by the specified angle. This is
48+
useful in cases where there are numerous categories and the labels overlap each other.
49+
4750
ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to plot
4851
the learning curve. If None, the plot is drawn on a new set of axes.
4952
@@ -96,7 +99,7 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, title=None, normalize=Fal
9699
plt.colorbar(mappable=image)
97100
tick_marks = np.arange(len(classes))
98101
ax.set_xticks(tick_marks)
99-
ax.set_xticklabels(classes, fontsize=text_fontsize)
102+
ax.set_xticklabels(classes, fontsize=text_fontsize, rotation=x_tick_rotation)
100103
ax.set_yticks(tick_marks)
101104
ax.set_yticklabels(classes, fontsize=text_fontsize)
102105

setup.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import os
77
import sys
88

9-
import scikitplot
10-
119
here = os.path.abspath(os.path.dirname(__file__))
1210

1311

@@ -37,7 +35,7 @@ def run_tests(self):
3735

3836
setup(
3937
name='scikit-plot',
40-
version=scikitplot.__version__,
38+
version='0.2.7',
4139
url='https://github.com/reiinakano/scikit-plot',
4240
license='MIT License',
4341
author='Reiichiro Nakano',

0 commit comments

Comments
 (0)