@@ -570,7 +570,8 @@ One good method to keep in mind is Gaussian Naive Bayes
     >>> from sklearn.model_selection import train_test_split
 
     >>> # split the data into training and validation sets
-    >>> X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target)
+    >>> X_train, X_test, y_train, y_test = train_test_split(
+    ...     digits.data, digits.target, random_state=42)
 
     >>> # train the model
     >>> clf = GaussianNB()
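
Pinning ``random_state`` is what makes every result printed below reproducible
across runs. For reference, a minimal self-contained sketch of the pipeline
these hunks exercise (assuming only scikit-learn and its bundled digits
dataset; with the default ``test_size=0.25`` the split leaves 450 validation
samples)::

    from sklearn.datasets import load_digits
    from sklearn.model_selection import train_test_split
    from sklearn.naive_bayes import GaussianNB

    # load the digits dataset: 1797 8x8 images, labels 0-9
    digits = load_digits()

    # fixed random_state so the split, and all printed output, is deterministic
    X_train, X_test, y_train, y_test = train_test_split(
        digits.data, digits.target, random_state=42)

    # fit the Gaussian Naive Bayes classifier on the training split
    clf = GaussianNB().fit(X_train, y_train)
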
@@ -581,9 +582,9 @@ One good method to keep in mind is Gaussian Naive Bayes
     >>> predicted = clf.predict(X_test)
     >>> expected = y_test
     >>> print(predicted)
-    [5 1 7 2 8 9 4 3 9 3 6 2 3 2 6 7 4 3 5 7 5 7 0 1 2 5 9 8 1 8 ...]
+    [6 9 3 7 2 2 5 8 5 2 1 1 7 0 4 8 3 7 8 8 4 3 9 7 5 6 3 5 6 3 ...]
     >>> print(expected)
-    [5 8 7 2 8 9 4 3 7 3 6 2 3 2 6 7 4 3 5 7 5 7 0 1 2 5 3 3 1 8 ...]
+    [6 9 3 7 2 1 5 2 5 2 1 9 4 0 4 2 3 7 8 8 4 3 9 7 5 6 3 5 6 3 ...]
 
 As above, we plot the digits with the predicted labels to get an idea of
 how well the classification is working.
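
The plotting step referred to here lies outside the hunk. One way to produce
such a figure with matplotlib, shown as a sketch (``X_test``, ``predicted``
and ``expected`` as defined above; the 8x8 grid of 64 test images is an
illustrative choice, not part of this diff)::

    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(6, 6))
    for i in range(64):
        ax = fig.add_subplot(8, 8, i + 1, xticks=[], yticks=[])
        ax.imshow(X_test[i].reshape(8, 8), cmap='binary', interpolation='nearest')
        # green label for a correct prediction, red for a misclassification
        colour = 'green' if predicted[i] == expected[i] else 'red'
        ax.text(0.05, 0.05, str(predicted[i]), transform=ax.transAxes, color=colour)
    plt.show()
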
@@ -607,11 +608,11 @@ the number of matches::
 
     >>> matches = (predicted == expected)
     >>> print(matches.sum())
-    371
+    385
     >>> print(len(matches))
     450
     >>> matches.sum() / float(len(matches))
-    0.82444...
+    0.8555...
 
 We see that more than 80% of the 450 predictions match the input. But
 there are other more sophisticated metrics that can be used to judge the
@@ -625,20 +626,20 @@ combines several measures and prints a table with the results::
     >>> print(metrics.classification_report(expected, predicted))
                   precision    recall  f1-score   support
     <BLANKLINE>
-               0       1.00      0.98      0.99        45
-               1       0.91      0.66      0.76        44
-               2       0.91      0.56      0.69        36
-               3       0.89      0.67      0.77        49
-               4       0.95      0.83      0.88        46
-               5       0.93      0.93      0.93        45
-               6       0.92      0.98      0.95        47
-               7       0.75      0.96      0.84        50
-               8       0.49      0.97      0.66        39
-               9       0.85      0.67      0.75        49
+               0       1.00      0.95      0.98        43
+               1       0.85      0.78      0.82        37
+               2       0.85      0.61      0.71        38
+               3       0.97      0.83      0.89        46
+               4       0.98      0.84      0.90        55
+               5       0.90      0.95      0.93        59
+               6       0.90      0.96      0.92        45
+               7       0.71      0.98      0.82        41
+               8       0.60      0.89      0.72        38
+               9       0.90      0.73      0.80        48
     <BLANKLINE>
-        accuracy                           0.82       450
-       macro avg       0.86      0.82      0.82       450
-    weighted avg       0.86      0.82      0.83       450
+        accuracy                           0.86       450
+       macro avg       0.87      0.85      0.85       450
+    weighted avg       0.88      0.86      0.86       450
     <BLANKLINE>
 
 
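
The report's accuracy row is the same quantity computed by hand above
(385/450 is approximately 0.86), and the per-class precision and recall
columns are available directly. A sketch using standard ``sklearn.metrics``
calls (not part of this diff)::

    from sklearn import metrics

    # identical to matches.sum() / len(matches) above: 385 / 450 = 0.8555...
    print(metrics.accuracy_score(expected, predicted))

    # per-class precision and recall, the first two columns of the report
    print(metrics.precision_score(expected, predicted, average=None))
    print(metrics.recall_score(expected, predicted, average=None))
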
@@ -647,16 +648,16 @@ is a *confusion matrix*: it helps us visualize which labels are being
 interchanged in the classification errors::
 
     >>> print(metrics.confusion_matrix(expected, predicted))
-    [[44  0  0  0  0  0  0  0  0  1]
-     [ 0 29  0  0  0  0  1  6  6  2]
-     [ 0  1 20  1  0  0  0  0 14  0]
-     [ 0  0  0 33  0  2  0  1 11  2]
-     [ 0  0  0  0 38  1  2  4  1  0]
-     [ 0  0  0  0  0 42  1  0  2  0]
-     [ 0  0  0  0  0  0 46  0  1  0]
-     [ 0  0  0  0  1  0  0 48  0  1]
-     [ 0  1  0  0  0  0  0  0 38  0]
-     [ 0  1  2  3  1  0  0  5  4 33]]
+    [[41  0  0  0  0  1  0  1  0  0]
+     [ 0 29  2  0  0  0  0  0  4  2]
+     [ 0  2 23  0  0  0  1  0 12  0]
+     [ 0  0  1 38  0  1  0  0  5  1]
+     [ 0  0  0  0 46  0  2  7  0  0]
+     [ 0  0  0  0  0 56  1  1  0  1]
+     [ 0  0  0  0  1  1 43  0  0  0]
+     [ 0  0  0  0  0  1  0 40  0  0]
+     [ 0  2  0  0  0  0  0  2 34  0]
+     [ 0  1  1  1  0  2  1  5  2 35]]
 
 We see here that in particular, the numbers 1, 2, 3, and 9 are often
 being labeled 8.
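
The concentration of errors in the 8 column is easier to see graphically.
A sketch of one way to render the matrix with matplotlib (scikit-learn 1.0
and later also offer ``metrics.ConfusionMatrixDisplay.from_predictions(expected,
predicted)`` as a one-liner)::

    import matplotlib.pyplot as plt
    from sklearn import metrics

    # rows are true labels, columns are predicted labels
    cm = metrics.confusion_matrix(expected, predicted)
    plt.imshow(cm, interpolation='nearest', cmap='Blues')
    plt.colorbar()
    plt.xlabel('predicted label')
    plt.ylabel('true label')
    plt.show()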