Skip to content
This repository was archived by the owner on Jan 8, 2026. It is now read-only.

Commit ceb383f

Browse files
committed
Add typing, and minor tweaks
1 parent 5a0b5d1 commit ceb383f

File tree

7 files changed

+41
-32
lines changed

7 files changed

+41
-32
lines changed

clusterium.code-workspace

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,13 @@
66
],
77
"settings": {
88
"git.enableSmartCommit": true,
9-
109
"makefile.configureOnOpen": false,
11-
1210
"python.analysis.typeCheckingMode": "basic",
1311
"python.testing.pytestArgs": [
1412
"tests"
1513
],
1614
"python.testing.unittestEnabled": false,
1715
"python.testing.pytestEnabled": true,
18-
1916
"[python]": {
2017
"editor.defaultFormatter": "ms-python.black-formatter",
2118
"editor.formatOnSave": true,
@@ -31,7 +28,6 @@
3128
"--python-version",
3229
"auto"
3330
],
34-
3531
"black-formatter.args": [
3632
"--extend-exclude",
3733
".poetry",
@@ -45,8 +41,11 @@
4541
},
4642
"isort.check": true,
4743
"isort.importStrategy": "fromEnvironment",
48-
49-
"flake8.path": ["${interpreter}", "-m", "flake8"],
44+
"flake8.path": [
45+
"${interpreter}",
46+
"-m",
47+
"flake8"
48+
],
5049
"flake8.args": [
5150
"--max-line-length",
5251
"88",

clusx/clustering/models.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,17 @@
3232
from scipy.special import logsumexp
3333
from sentence_transformers import SentenceTransformer
3434

35+
from clusx.logging import get_logger
36+
from clusx.utils import to_numpy
37+
3538
if TYPE_CHECKING:
36-
from typing import Any, Optional, Union
39+
from typing import Optional, Union
3740

3841
import torch
3942
from numpy.typing import NDArray
4043

4144
EmbeddingTensor = Union[torch.Tensor, NDArray[np.float32]]
4245

43-
from clusx.logging import get_logger
44-
from clusx.utils import to_numpy
45-
4646
logger = get_logger(__name__)
4747

4848

@@ -122,8 +122,8 @@ def __init__(
122122
# For reproducibility
123123
self.random_state = np.random.default_rng(seed=random_state)
124124

125-
self.clusters = []
126-
self.cluster_params = {}
125+
self.clusters: list[int] = []
126+
self.cluster_params: dict[int, dict[str, EmbeddingTensor | int]] = {}
127127
self.global_mean: Optional[EmbeddingTensor] = None
128128
self.next_id = 0
129129
self.embeddings_: Optional[EmbeddingTensor] = None
@@ -374,7 +374,7 @@ def _calculate_cluster_probabilities(
374374

375375
# Convert log scores to probabilities
376376
scores = np.array(scores)
377-
scores -= logsumexp(scores)
377+
scores -= logsumexp(scores) # type: ignore
378378
probabilities = np.exp(scores) # type: np.ndarray
379379

380380
# Add placeholder for new cluster ID
@@ -478,7 +478,7 @@ def assign_cluster(self, embedding: EmbeddingTensor) -> tuple[int, np.ndarray]:
478478

479479
return cluster_id, probs
480480

481-
def fit(self, documents, _y: Union[Any, None] = None):
481+
def fit(self, documents, _y=None):
482482
"""
483483
Train the clustering model on the given text data.
484484
@@ -490,7 +490,7 @@ def fit(self, documents, _y: Union[Any, None] = None):
490490
----------
491491
documents : Union[list[str], list[EmbeddingTensor]]
492492
The text documents or embeddings to cluster.
493-
_y : Union[Any, None]
493+
_y
494494
Ignored. Added for compatibility with scikit-learn API.
495495
496496
Returns
@@ -580,15 +580,15 @@ def predict(self, documents):
580580

581581
return np.array(predictions)
582582

583-
def fit_predict(self, documents, _y: Union[Any, None] = None):
583+
def fit_predict(self, documents, _y=None):
584584
"""
585585
Fit the model and predict cluster labels for documents.
586586
587587
Parameters
588588
----------
589589
documents : Union[list[str], list[EmbeddingTensor]]
590590
The text documents or embeddings to cluster.
591-
_y : Union[Any, None]
591+
_y
592592
This parameter exists only for compatibility with scikit-learn API.
593593
594594
Returns
@@ -837,7 +837,7 @@ def _calculate_cluster_probabilities(
837837

838838
# Convert log scores to probabilities
839839
scores = np.array(scores)
840-
scores -= logsumexp(scores)
840+
scores -= logsumexp(scores) # type: ignore
841841
probabilities = np.exp(scores)
842842

843843
# Add placeholder for new cluster ID

clusx/clustering/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def save_clusters_to_json(
161161
"""
162162
# Group texts by cluster
163163
cluster_groups = defaultdict(list)
164-
for text, cluster_id in zip(texts or [], clusters or []):
164+
for text, cluster_id in zip(texts, clusters):
165165
cluster_groups[cluster_id].append(text)
166166

167167
clusters_json = {

clusx/evaluation.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,17 @@
3232
from typing import TYPE_CHECKING
3333

3434
import numpy as np
35-
from sklearn.metrics import silhouette_score
36-
from sklearn.metrics.pairwise import cosine_similarity
37-
from sklearn.neighbors import NearestNeighbors
35+
from sklearn.metrics import silhouette_score # type: ignore
36+
from sklearn.metrics.pairwise import cosine_similarity # type: ignore
37+
from sklearn.neighbors import NearestNeighbors # type: ignore
38+
39+
from clusx.errors import EvaluationError
40+
from clusx.logging import get_logger
3841

3942
if TYPE_CHECKING:
40-
import numpy # pylint: disable=reimported
4143
from typing import Any, Union
4244

43-
from .errors import EvaluationError
44-
from .logging import get_logger
45+
import numpy # pylint: disable=reimported
4546

4647
logger = get_logger(__name__)
4748

@@ -152,9 +153,10 @@ def __init__(
152153

153154
# Validate inputs
154155
if len(texts) != len(embeddings) or len(texts) != len(cluster_assignments):
155-
raise ValueError(
156+
raise EvaluationError(
156157
"Length mismatch: texts, embeddings, and cluster_assignments "
157-
"must have the same length"
158+
f"must have the same length, got {len(texts)}, {len(embeddings)}, "
159+
f"and {len(cluster_assignments)} respectively",
158160
)
159161

160162
logger.info(
@@ -191,7 +193,7 @@ def calculate_silhouette_score(self) -> float:
191193
is not possible
192194
"""
193195
# Count samples per cluster
194-
cluster_counts = {}
196+
cluster_counts: dict[int, int] = {}
195197
for cluster_id in self.cluster_assignments:
196198
cluster_counts[cluster_id] = cluster_counts.get(cluster_id, 0) + 1
197199

@@ -294,6 +296,7 @@ def calculate_similarity_metrics(
294296

295297
# Calculate intra-cluster similarities
296298
intra_sims = []
299+
# TODO: On a file of 170000 lines at this point we die.
297300
for cluster_indices in valid_clusters.values():
298301
cluster_embeddings = self.embeddings[cluster_indices]
299302
sim_matrix = cosine_similarity(cluster_embeddings)
@@ -379,7 +382,7 @@ def detect_powerlaw_distribution(self) -> dict[str, Any]:
379382
}
380383

381384
try:
382-
import powerlaw
385+
import powerlaw # type: ignore
383386

384387
# 1. Get cluster sizes
385388
cluster_sizes = []

clusx/visualization.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
import numpy as np
1616
from matplotlib import colormaps
1717

18+
from clusx.errors import VisualizationError
19+
from clusx.logging import get_logger
20+
1821
if TYPE_CHECKING:
1922
from typing import Any
23+
2024
from matplotlib.axes import Axes
2125

22-
from .errors import VisualizationError
23-
from .logging import get_logger
2426

2527
logger = get_logger(__name__)
2628

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
# Add any paths that contain templates here, relative to this directory.
3333
templates_path = ["_templates"]
34-
exclude_patterns = []
34+
exclude_patterns: list[str] = []
3535

3636
# -- Options for nitpick -----------------------------------------------------
3737

tests/load/profile-4.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
--dp-alpha 10 --dp-kappa 5 --pyp-alpha 5 --pyp-kappa 10 --pyp-sigma 0.5
2+
--dp-alpha 15 --dp-kappa 10 --pyp-alpha 12 --pyp-kappa 10 --pyp-sigma 0.7
3+
--dp-alpha 25 --dp-kappa 8 --pyp-alpha 20 --pyp-kappa 8 --pyp-sigma 0.8
4+
--dp-alpha 18 --dp-kappa 12 --pyp-alpha 15 --pyp-kappa 12 --pyp-sigma 0.7
5+
--dp-alpha 30 --dp-kappa 6 --pyp-alpha 20 --pyp-kappa 6 --pyp-sigma 0.9

0 commit comments

Comments
 (0)