Skip to content
This repository was archived by the owner on Aug 9, 2023. It is now read-only.

Commit e81be34

Browse files
author
Campbells
authored
Merge pull request #276 from wellcometrust/feature/clustering-improvements
Feature/clustering improvements
2 parents b5ef951 + 4003a70 commit e81be34

File tree

2 files changed

+104
-7
lines changed

2 files changed

+104
-7
lines changed

tests/test_clustering.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
@pytest.mark.parametrize("reducer,cluster_reduced", [("tsne", True),
77
("umap", True),
88
("umap", False)])
9-
def test_full_pipeline(reducer, cluster_reduced):
9+
def test_full_pipeline(reducer, cluster_reduced, tmp_path):
1010
cluster = TextClustering(reducer=reducer, cluster_reduced=cluster_reduced,
1111
embedding_random_state=42,
1212
reducer_random_state=43,
@@ -23,6 +23,17 @@ def test_full_pipeline(reducer, cluster_reduced):
2323

2424
assert len(cluster.cluster_kws) == len(cluster.cluster_ids) == 6
2525

26+
cluster.save(folder=tmp_path)
27+
28+
cluster_new = TextClustering()
29+
cluster_new.load(folder=tmp_path)
30+
31+
# Asserts all coordinates of the loaded points are equal
32+
assert (cluster_new.embedded_points != cluster.embedded_points).sum() == 0
33+
assert (cluster_new.reduced_points != cluster.reduced_points).sum() == 0
34+
assert cluster_new.reducer_class.__class__ == cluster.reducer_class.__class__
35+
assert cluster_new.clustering_class.__class__ == cluster.clustering_class.__class__
36+
2637

2738
@pytest.mark.parametrize("reducer", ["tsne", "umap"])
2839
def test_parameter_search(reducer):

wellcomeml/ml/clustering.py

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections import defaultdict
22
import logging
33
import os
4+
import pickle
45

56
from wellcomeml.ml import vectorizer
67
from wellcomeml.logger import logger
@@ -39,6 +40,7 @@ class TextClustering(object):
3940
cluster_names: Names of the clusters
4041
cluster_kws: Keywords for the clusters (only if embedding=tf-idf)
4142
"""
43+
4244
def __init__(self, embedding='tf-idf', reducer='umap', clustering='dbscan',
4345
cluster_reduced=True, n_kw=10, params={},
4446
embedding_random_state=None, reducer_random_state=None,
@@ -113,13 +115,21 @@ class that is a sklearn.base.ClusterMixin
113115
'random_state'):
114116
self.clustering_class.random_state = clustering_random_state
115117

118+
self.embedded_points = None
119+
self.reduced_points = None
116120
self.cluster_ids = None
117121
self.cluster_names = None
118122
self.cluster_kws = None
119123
self.kw_dictionary = {}
120124
self.silhouette = None
121125
self.optimise_results = {}
122126

127+
self.embedded_points_filename = 'embedded_points.npy'
128+
self.reduced_points_filename = 'reduced_points.npy'
129+
self.vectorizer_filename = 'vectorizer.pkl'
130+
self.reducer_filename = 'reducer.pkl'
131+
self.clustering_filename = 'clustering.pkl'
132+
123133
def fit(self, X, *_):
124134
"""
125135
Fits all clusters in the pipeline
@@ -131,22 +141,28 @@ def fit(self, X, *_):
131141
A TextClustering object
132142
133143
"""
134-
self._fit_step(X, step='vectorizer')
135-
self._fit_step(step='reducer')
136-
self._fit_step(step='clustering')
144+
self.fit_step(X, step='vectorizer')
145+
self.fit_step(step='reducer')
146+
self.fit_step(step='clustering')
137147

138148
if self.embedding == 'tf-idf' and self.n_kw:
139149
self._find_keywords(self.embedded_points.toarray(), n_kw=self.n_kw)
140150

141151
return self
142152

143-
def _fit_step(self, X=None, step='vectorizer'):
153+
def fit_step(self, X=None, y=None, step='vectorizer'):
144154
"""Internal function for partial fitting only a certain step"""
145155
if step == 'vectorizer':
146156
self.embedded_points = self.vectorizer.fit_transform(X)
147157
elif step == 'reducer':
148-
self.reduced_points = \
149-
self.reducer_class.fit_transform(self.embedded_points)
158+
if self.embedded_points is None:
159+
raise ValueError(
160+
'You must embed/vectorise the points before reducing dimensionality'
161+
)
162+
if X is None:
163+
X = self.embedded_points
164+
165+
self.reduced_points = self.reducer_class.fit_transform(X=X, y=y)
150166
elif step == 'clustering':
151167
points = (
152168
self.reduced_points if self.cluster_reduced else
@@ -260,7 +276,9 @@ def optimise(self, X, param_grid, n_cluster_range=None, max_noise=0.2,
260276
# Prunes result to actually optimise under constraints
261277
best_silhouette = 0
262278
best_params = {}
279+
263280
grid.fit(X, y=None)
281+
264282
for params, silhouette, noise, n_clusters in zip(
265283
grid.cv_results_['params'],
266284
grid.cv_results_['mean_test_silhouette'],
@@ -292,6 +310,74 @@ def optimise(self, X, param_grid, n_cluster_range=None, max_noise=0.2,
292310

293311
return best_params
294312

313+
def save(self, folder, components='all', create_folder=True):
314+
"""
315+
Saves the different steps of the pipeline
316+
317+
Args:
318+
folder(str): path to folder
319+
components(list or 'all'): List of components to save. Options are: 'embedded_points',
320+
'reduced_points', 'vectorizer', 'reducer', and 'clustering_model'. By default, loads
321+
'all' (you can get all components by listing the class param
322+
TextClustering.components)
323+
324+
"""
325+
if create_folder:
326+
os.makedirs(folder, exist_ok=True)
327+
328+
if components == 'all' or 'embedded_points' in components:
329+
np.save(os.path.join(folder, self.embedded_points_filename), self.embedded_points)
330+
331+
if components == 'all' or 'reduced_points' in components:
332+
np.save(os.path.join(folder, self.reduced_points_filename), self.reduced_points)
333+
334+
if components == 'all' or 'vectorizer' in components:
335+
with open(os.path.join(folder, self.vectorizer_filename), 'wb') as f:
336+
pickle.dump(self.vectorizer, f)
337+
338+
if components == 'all' or 'reducer' in components:
339+
with open(os.path.join(folder, self.reducer_filename), 'wb') as f:
340+
pickle.dump(self.reducer_class, f)
341+
342+
if components == 'all' or 'clustering_model' in components:
343+
with open(os.path.join(folder, self.clustering_filename), 'wb') as f:
344+
pickle.dump(self.clustering_class, f)
345+
346+
def load(self, folder, components='all'):
347+
"""
348+
Loads the different steps of the pipeline
349+
350+
Args:
351+
folder(str): path to folder
352+
components(list or 'all'): List of components to load. Options are: 'embedded_points',
353+
'reduced_points', 'vectorizer', 'reducer', and 'clustering_model'. By default, loads
354+
'all' (you can get all components by listing the class param
355+
TextClustering.components)
356+
357+
"""
358+
359+
if components == 'all' or 'embedded_points' in components:
360+
self.embedded_points = np.load(os.path.join(folder, self.embedded_points_filename),
361+
allow_pickle=True)
362+
if not self.embedded_points.shape:
363+
self.embedded_points = self.embedded_points[()]
364+
365+
if components == 'all' or 'reduced_points' in components:
366+
self.reduced_points = np.load(os.path.join(folder, self.reduced_points_filename),
367+
allow_pickle=True)
368+
369+
if components == 'all' or 'vectorizer' in components:
370+
with open(os.path.join(folder, self.vectorizer_filename), 'rb') as f:
371+
self.vectorizer = pickle.load(f)
372+
373+
if components == 'all' or 'reducer' in components:
374+
with open(os.path.join(folder, self.reducer_filename), 'rb') as f:
375+
self.reducer_class = pickle.load(f)
376+
377+
if components == 'all' or 'clustering_model' in components:
378+
with open(os.path.join(folder, self.clustering_filename), 'rb') as f:
379+
self.clustering_class = pickle.load(f)
380+
295381
def stability(self):
296382
"""Function to calculate how stable the clusters are"""
297383
raise NotImplementedError

0 commit comments

Comments
 (0)