11from collections import defaultdict
22import logging
33import os
4+ import pickle
45
56from wellcomeml .ml import vectorizer
67from wellcomeml .logger import logger
@@ -39,6 +40,7 @@ class TextClustering(object):
3940 cluster_names: Names of the clusters
4041 cluster_kws: Keywords for the clusters (only if embedding=tf-idf)
4142 """
43+
4244 def __init__ (self , embedding = 'tf-idf' , reducer = 'umap' , clustering = 'dbscan' ,
4345 cluster_reduced = True , n_kw = 10 , params = {},
4446 embedding_random_state = None , reducer_random_state = None ,
@@ -113,13 +115,21 @@ class that is a sklearn.base.ClusterMixin
113115 'random_state' ):
114116 self .clustering_class .random_state = clustering_random_state
115117
118+ self .embedded_points = None
119+ self .reduced_points = None
116120 self .cluster_ids = None
117121 self .cluster_names = None
118122 self .cluster_kws = None
119123 self .kw_dictionary = {}
120124 self .silhouette = None
121125 self .optimise_results = {}
122126
127+ self .embedded_points_filename = 'embedded_points.npy'
128+ self .reduced_points_filename = 'reduced_points.npy'
129+ self .vectorizer_filename = 'vectorizer.pkl'
130+ self .reducer_filename = 'reducer.pkl'
131+ self .clustering_filename = 'clustering.pkl'
132+
123133 def fit (self , X , * _ ):
124134 """
125135 Fits all clusters in the pipeline
@@ -131,22 +141,28 @@ def fit(self, X, *_):
131141 A TextClustering object
132142
133143 """
134- self ._fit_step (X , step = 'vectorizer' )
135- self ._fit_step (step = 'reducer' )
136- self ._fit_step (step = 'clustering' )
144+ self .fit_step (X , step = 'vectorizer' )
145+ self .fit_step (step = 'reducer' )
146+ self .fit_step (step = 'clustering' )
137147
138148 if self .embedding == 'tf-idf' and self .n_kw :
139149 self ._find_keywords (self .embedded_points .toarray (), n_kw = self .n_kw )
140150
141151 return self
142152
143- def _fit_step (self , X = None , step = 'vectorizer' ):
153+ def fit_step (self , X = None , y = None , step = 'vectorizer' ):
144154 """Internal function for partial fitting only a certain step"""
145155 if step == 'vectorizer' :
146156 self .embedded_points = self .vectorizer .fit_transform (X )
147157 elif step == 'reducer' :
148- self .reduced_points = \
149- self .reducer_class .fit_transform (self .embedded_points )
158+ if self .embedded_points is None :
159+ raise ValueError (
160+ 'You must embed/vectorise the points before reducing dimensionality'
161+ )
162+ if X is None :
163+ X = self .embedded_points
164+
165+ self .reduced_points = self .reducer_class .fit_transform (X = X , y = y )
150166 elif step == 'clustering' :
151167 points = (
152168 self .reduced_points if self .cluster_reduced else
@@ -260,7 +276,9 @@ def optimise(self, X, param_grid, n_cluster_range=None, max_noise=0.2,
260276 # Prunes result to actually optimise under constraints
261277 best_silhouette = 0
262278 best_params = {}
279+
263280 grid .fit (X , y = None )
281+
264282 for params , silhouette , noise , n_clusters in zip (
265283 grid .cv_results_ ['params' ],
266284 grid .cv_results_ ['mean_test_silhouette' ],
@@ -292,6 +310,74 @@ def optimise(self, X, param_grid, n_cluster_range=None, max_noise=0.2,
292310
293311 return best_params
294312
def save(self, folder, components='all', create_folder=True):
    """
    Writes the selected pipeline components to files inside a folder

    Args:
        folder(str): path to the folder the files are written to
        components(list or 'all'): Names of the components to save. Options are:
            'embedded_points', 'reduced_points', 'vectorizer', 'reducer', and
            'clustering_model'. The default, 'all', saves every component
        create_folder(bool): If True, creates `folder` (including missing parent
            folders) before writing; an existing folder is left untouched

    """
    def requested(name):
        # A component is written when everything was asked for, or when it
        # is named explicitly in `components`
        return components == 'all' or name in components

    if create_folder:
        os.makedirs(folder, exist_ok=True)

    # (component name, attribute holding the file name, attribute holding the data)
    point_arrays = (
        ('embedded_points', 'embedded_points_filename', 'embedded_points'),
        ('reduced_points', 'reduced_points_filename', 'reduced_points'),
    )
    pickled_models = (
        ('vectorizer', 'vectorizer_filename', 'vectorizer'),
        ('reducer', 'reducer_filename', 'reducer_class'),
        ('clustering_model', 'clustering_filename', 'clustering_class'),
    )

    # Point matrices go through numpy's own serialisation
    for name, filename_attr, data_attr in point_arrays:
        if requested(name):
            target = os.path.join(folder, getattr(self, filename_attr))
            np.save(target, getattr(self, data_attr))

    # Fitted estimators are pickled as whole objects
    for name, filename_attr, model_attr in pickled_models:
        if requested(name):
            target = os.path.join(folder, getattr(self, filename_attr))
            with open(target, 'wb') as handle:
                pickle.dump(getattr(self, model_attr), handle)
def load(self, folder, components='all'):
    """
    Loads the different steps of the pipeline from files inside a folder

    The file names are taken from the `*_filename` attributes of the
    object, and the loaded values are assigned to `embedded_points`,
    `reduced_points`, `vectorizer`, `reducer_class` and `clustering_class`.

    Args:
        folder(str): path to the folder holding the saved files
        components(list or 'all'): Names of the components to load. Options are:
            'embedded_points', 'reduced_points', 'vectorizer', 'reducer', and
            'clustering_model'. The default, 'all', loads every component

    """
    # NOTE(security): both np.load(allow_pickle=True) and pickle.load can
    # execute arbitrary code while deserialising. Only load folders that
    # come from a trusted source.

    def requested(name):
        return components == 'all' or name in components

    def read_points(filename):
        loaded = np.load(os.path.join(folder, filename), allow_pickle=True)
        # Objects that are not plain arrays (e.g. a sparse matrix or None)
        # come back from numpy wrapped in a 0-d object array; indexing
        # with an empty tuple recovers the original object.
        if not loaded.shape:
            loaded = loaded[()]
        return loaded

    def read_pickle(filename):
        with open(os.path.join(folder, filename), 'rb') as f:
            return pickle.load(f)

    if requested('embedded_points'):
        self.embedded_points = read_points(self.embedded_points_filename)

    if requested('reduced_points'):
        # Unwrapped the same way as embedded_points, so that a value saved
        # as None is restored as None rather than as a 0-d array.
        self.reduced_points = read_points(self.reduced_points_filename)

    if requested('vectorizer'):
        self.vectorizer = read_pickle(self.vectorizer_filename)

    if requested('reducer'):
        self.reducer_class = read_pickle(self.reducer_filename)

    if requested('clustering_model'):
        self.clustering_class = read_pickle(self.clustering_filename)
def stability(self):
    """
    Placeholder for a measure of how stable the clusters are

    Raises:
        NotImplementedError: always, as no implementation exists yet

    """
    raise NotImplementedError()
0 commit comments