1616def plot_tsne (
1717 data_paths : list [Path ],
1818 save_path : Path ,
19- method : Literal ['t-SNE' ] = 't-SNE' ,
2019 metric : Literal ['jaccard' , 'euclidean' ] = 'jaccard' ,
2120 embedder : Literal ['morgan' , 'file' ] = 'morgan' ,
2221 max_molecules : list [int ] | None = None ,
2322 colors : list [str ] | None = None ,
2423 smiles_columns : list [str ] | None = None ,
2524 data_names : list [str ] | None = None ,
2625 highlight_data_names : set [str ] | None = None ,
27- display_data_names : set [str ] | None = None
26+ display_data_names : set [str ] | None = None ,
27+ point_size : int = 1000
2828) -> None :
2929 """Runs a t-SNE on molecular fingerprints from one or more chemical libraries.
3030
3131 :param data_paths: Path to CSV files containing SMILES.
32- :param save_path: Path to a PDF file where the dimensionality reduction plot will be saved.
33- :param method: Dimensionality reduction method.
32+ :param save_path: Path to a PDF file where the t-SNE plot will be saved.
3433 :param metric: Metric to use to compared embeddings.
3534 :param embedder: Embedding to use for the molecules.
3635 morgan: Computes Morgan fingerprint from the SMILES.
@@ -43,6 +42,7 @@ def plot_tsne(
4342 :param data_names: Names of the data files for labeling the plot.
4443 :param highlight_data_names: Names of the data files to highlight in the plot.
4544 :param display_data_names: The names of the data files to display in the plot. If None, all are displayed.
45+ :param point_size: The size of the points in the plot.
4646 """
4747 # Validate max_molecules
4848 if max_molecules is None :
@@ -87,8 +87,9 @@ def plot_tsne(
8787
8888 # Load data and subsample SMILES
8989 smiles , slices , embeddings = [], [], []
90- for data_path , smiles_column , data_name , max_mols in tqdm (zip (data_paths , smiles_columns , data_names , max_molecules ),
91- total = len (data_paths ), desc = 'Loading data' ):
90+ for data_path , smiles_column , data_name , max_mols in tqdm (
91+ zip (data_paths , smiles_columns , data_names , max_molecules ),
92+ total = len (data_paths ), desc = 'Loading data' ):
9293 # Load data
9394 data = pd .read_csv (data_path )
9495 print (f'{ data_name } : { len (data ):,} ' )
@@ -118,15 +119,12 @@ def plot_tsne(
118119 else :
119120 raise ValueError (f'Embedder "{ embedder } " is not supported.' )
120121
121- # Run dimensionality reduction
122- if method == 't-SNE' :
123- reducer = TSNE (random_state = 0 , metric = metric , init = 'pca' , n_jobs = - 1 , square_distances = True )
124- else :
125- raise ValueError (f'Dimensionality reduction method "{ method } " is not supported.' )
122+ # Run t-SNE
123+ tsne = TSNE (random_state = 0 , metric = metric , init = 'pca' , n_jobs = - 1 )
126124
127- print (f'Running { method } ' )
125+ print (f'Running t-SNE ' )
128126 start = time .time ()
129- X = reducer .fit_transform (embeddings )
127+ X = tsne .fit_transform (embeddings )
130128 print (f'time = { time .time () - start :.2f} seconds' )
131129
132130 print ('Plotting' )
@@ -135,15 +133,15 @@ def plot_tsne(
135133
136134 plt .clf ()
137135 plt .figure (figsize = (64 , 48 ))
138- plt .title (f'{ method } using Morgan fingerprint with { metric .title ()} similarity' , fontsize = 100 )
136+ plt .title (f't-SNE using Morgan fingerprint with { metric .title ()} similarity' , fontsize = 100 )
139137
140138 tsne_data = {}
141139 for index , (slc , data_name ) in enumerate (zip (slices , data_names )):
142140 if display_data_names is None or data_name in display_data_names :
143141 plt .scatter (
144142 X [slc , 0 ],
145143 X [slc , 1 ],
146- s = 1500 if data_name in highlight_data_names else 1000 ,
144+ s = 1.5 * point_size if data_name in highlight_data_names else point_size ,
147145 color = colors [index ],
148146 label = data_name ,
149147 marker = '*' if data_name in highlight_data_names else '.'
0 commit comments