Skip to content
This repository was archived by the owner on Jan 8, 2026. It is now read-only.

Commit 75bf3d2

Browse files
committed
Correct tests
1 parent 7793f24 commit 75bf3d2

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tests/clustering/test_models.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,30 +71,30 @@ def test_dp_normalize(dp_instance):
7171
assert np.isclose(normalized[1], 0.8)
7272

7373

74-
def test_dp_log_likelihood_base(dp_instance, embedding_fx):
74+
def test_dp_log_likelihood_vmf(dp_instance, embedding_fx):
7575
"""Test log likelihood calculation for a cluster."""
7676
cluster_id = 0
7777
dp_instance.cluster_params[cluster_id] = {"mean": embedding_fx, "count": 1}
78-
likelihood = dp_instance._log_likelihood_base(embedding_fx, cluster_id)
78+
likelihood = dp_instance._log_likelihood_vmf(embedding_fx, cluster_id)
7979

8080
assert isinstance(likelihood, float)
8181
assert likelihood > 0
8282

8383
orthogonal = np.array([0.0, 0.0, 0.0, 1.0])
84-
likelihood_orthogonal = dp_instance._log_likelihood_base(orthogonal, cluster_id)
84+
likelihood_orthogonal = dp_instance._log_likelihood_vmf(orthogonal, cluster_id)
8585

8686
assert likelihood_orthogonal < likelihood
8787

8888

89-
def test_dp_log_likelihood_base_nonexistent_cluster(dp_instance, embedding_fx):
89+
def test_dp_log_likelihood_vmf_nonexistent_cluster(dp_instance, embedding_fx):
9090
"""Test log likelihood for a nonexistent cluster."""
9191
dp_instance.global_mean = embedding_fx
92-
likelihood = dp_instance._log_likelihood_base(embedding_fx, 999)
92+
likelihood = dp_instance._log_likelihood_vmf(embedding_fx, 999)
9393

9494
assert isinstance(likelihood, float)
9595

9696
dp_instance.global_mean = None
97-
likelihood = dp_instance._log_likelihood_base(embedding_fx, 999)
97+
likelihood = dp_instance._log_likelihood_vmf(embedding_fx, 999)
9898
assert likelihood == 0.0
9999

100100

@@ -233,7 +233,7 @@ def test_dp_predict(dp_instance, embedding_fx):
233233

234234
with (
235235
patch.object(dp_instance, "get_embedding") as mock_embed,
236-
patch.object(dp_instance, "_log_likelihood_base") as mock_likelihood,
236+
patch.object(dp_instance, "_log_likelihood_vmf") as mock_likelihood,
237237
):
238238
mock_embed.return_value = np.array(
239239
[
@@ -410,7 +410,7 @@ def test_pyp_predict(pyp_instance, embedding_fx):
410410

411411
with (
412412
patch.object(pyp_instance, "get_embedding") as mock_embed,
413-
patch.object(pyp_instance, "_log_likelihood_base") as mock_likelihood,
413+
patch.object(pyp_instance, "_log_likelihood_vmf") as mock_likelihood,
414414
):
415415
mock_embed.return_value = np.array(
416416
[

0 commit comments

Comments
 (0)