Commit b177389

Fix codes for ham/spam messages
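Newer versions of SciPy are not deterministic about which code number is assigned to which cluster, so the hard-coded mapping (0 for spam, 1 for ham, 2 for unknown) is no longer reliable. This change uses np.unique with return_index=True to find where each code first appears, then uses argmin to pick the code whose first occurrence has the lowest index. From the setup of the model, this is the ham category. Likewise, argmax picks the code whose first occurrence has the highest index, which is the spam category. The code for the unknown category is then the one value among 0, 1, and 2 that has not been used yet.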
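The selection logic described above can be tried in isolation with a minimal sketch. The `codes` array here is made up for illustration and is not output from the actual model; it assumes, as the commit message does, that the observations are ordered so that ham-like entries come first and spam-like entries come last. The sketch uses a set difference rather than the symmetric difference in the commit, which gives the same result whenever both codes are in {0, 1, 2}.

```python
import numpy as np

# Hypothetical cluster assignments (illustrative only, not real model output).
# Assumption: observations are ordered so ham-like entries come first
# and spam-like entries come last.
codes = np.array([2, 2, 2, 0, 0, 1, 1])

possible_codes = {0, 1, 2}

# unique_codes is the sorted set of codes; code_indices holds the index of
# the first occurrence of each code in `codes`.
unique_codes, code_indices = np.unique(codes, return_index=True)

# The code that shows up earliest is treated as ham, the one that first
# shows up latest is treated as spam.
ham_code = int(unique_codes[np.argmin(code_indices)])
spam_code = int(unique_codes[np.argmax(code_indices)])

# Whatever code is left over is the unknown category.
(unknown_code,) = possible_codes - {ham_code, spam_code}

print("ham:", ham_code, "spam:", spam_code, "unknown:", unknown_code)
```

Relabeling the clusters in `codes` (for example, swapping every 2 with a 0) changes the numeric results but not which group of observations is identified as ham, spam, or unknown, which is the point of the fix.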
1 parent 878b161 commit b177389

1 file changed: +13 -7 lines changed

python-scipy-cluster-optimize/cluster_sms_spam.py

Lines changed: 13 additions & 7 deletions
@@ -20,19 +20,25 @@
 codebook, _ = kmeans(whitened_counts, 3)
 codes, _ = vq(whitened_counts, codebook)
 
-print("definitely spam:", unique_counts[codes == 0][-1])
-print("definitely ham:", unique_counts[codes == 1][-1])
-print("unknown:", unique_counts[codes == 2][-1])
+possible_codes = {0, 1, 2}
+unique_codes, code_indices = np.unique(codes, return_index=True)
+ham_code = unique_codes[np.argmin(code_indices)]
+spam_code = unique_codes[np.argmax(code_indices)]
+unknown_code = list(possible_codes ^ set((ham_code, spam_code)))[0]
+
+print("definitely ham:", unique_counts[codes == ham_code][-1])
+print("definitely spam:", unique_counts[codes == spam_code][-1])
+print("unknown:", unique_counts[codes == unknown_code][-1])
 
 digits = digit_counts[:, 1]
 predicted_hams = digits == 0
 predicted_spams = digits > 20
 predicted_unknowns = np.logical_and(digits > 0, digits <= 20)
 
-spam_cluster = digit_counts[predicted_spams]
 ham_cluster = digit_counts[predicted_hams]
+spam_cluster = digit_counts[predicted_spams]
 unknown_cluster = digit_counts[predicted_unknowns]
 
-print("definitely ham:", np.unique(ham_cluster[:, 0], return_counts=True))
-print("definitely spam:", np.unique(spam_cluster[:, 0], return_counts=True))
-print("unknown:", np.unique(unknown_cluster[:, 0], return_counts=True))
+print("hams:", np.unique(ham_cluster[:, 0], return_counts=True))
+print("spams:", np.unique(spam_cluster[:, 0], return_counts=True))
+print("unknowns:", np.unique(unknown_cluster[:, 0], return_counts=True))
