|
17 | 17 | assert_not_in, |
18 | 18 | assert_no_warnings, |
19 | 19 | if_matplotlib) |
20 | | -from hdbscan import HDBSCAN, hdbscan, validity_index |
| 20 | +from hdbscan import (HDBSCAN, |
| 21 | + hdbscan, |
| 22 | + validity_index, |
| 23 | + approximate_predict, |
| 24 | + membership_vector, |
| 25 | + all_points_membership_vectors) |
21 | 26 | # from sklearn.cluster.tests.common import generate_clustered_data |
22 | 27 | from sklearn.datasets import make_blobs |
23 | 28 | from sklearn.utils import shuffle |
@@ -438,6 +443,14 @@ def test_hdbscan_min_span_tree_availability(): |
438 | 443 | tree = clusterer.minimum_spanning_tree_ |
439 | 444 | assert tree is None |
440 | 445 |
|
| 446 | +def test_hdbscan_approximate_predict(): |
| 447 | + clusterer = HDBSCAN(prediction_data=True).fit(X) |
| 448 | + cluster, prob = approximate_predict(clusterer, np.array([[-1.5, -1.0]])) |
| 449 | + assert_equal(cluster, 2) |
| 450 | + cluster, prob = approximate_predict(clusterer, np.array([[1.5, -1.0]])) |
| 451 | + assert_equal(cluster, 1) |
| 452 | + cluster, prob = approximate_predict(clusterer, np.array([[0.0, 0.0]])) |
| 453 | + assert_equal(cluster, -1) |
441 | 454 |
|
442 | 455 | def test_hdbscan_badargs(): |
443 | 456 | assert_raises(ValueError, |
|
0 commit comments