|
12 | 12 | from sklearn.utils.estimator_checks import check_estimator |
13 | 13 | from sklearn.utils.testing import (assert_equal, |
14 | 14 | assert_array_equal, |
| 15 | + assert_array_almost_equal, |
15 | 16 | assert_raises, |
16 | 17 | assert_in, |
17 | 18 | assert_not_in, |
@@ -452,6 +453,16 @@ def test_hdbscan_approximate_predict(): |
452 | 453 | cluster, prob = approximate_predict(clusterer, np.array([[0.0, 0.0]])) |
453 | 454 | assert_equal(cluster, -1) |
454 | 455 |
|
| 456 | +def test_hdbscan_membership_vector(): |
| 457 | + clusterer = HDBSCAN(prediction_data=True).fit(X) |
| 458 | + vector = membership_vector(clusterer, np.array([[-1.5, -1.0]])) |
| 459 | + assert_array_almost_equal(vector, np.array([[ 0.05705305, 0.05974177, 0.12228153]])) |
| 460 | + vector = membership_vector(clusterer, np.array([[1.5, -1.0]])) |
| 461 | + assert_array_almost_equal(vector, np.array([[ 0.09462176, 0.32061556, 0.10112905]])) |
| 462 | + vector = membership_vector(clusterer, np.array([[0.0, 0.0]])) |
| 463 | + assert_array_almost_equal(vector, np.array([[ 0.03545607, 0.03363318, 0.04643177]])) |
| 464 | + |
| 465 | + |
455 | 466 | def test_hdbscan_badargs(): |
456 | 467 | assert_raises(ValueError, |
457 | 468 | hdbscan, |
|
0 commit comments