Skip to content

Commit 269c7c9

Browse files
committed
Fixing t-SNE parameters and incrementing version to 1.0.5
1 parent e8afe20 commit 269c7c9

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

chemfunc/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Contains the version information for chemfunc."""
22
# major, minor, patch
3-
version_info = 1, 0, 4
3+
version_info = 1, 0, 5
44

55
# Nice string for the version
66
__version__ = '.'.join(map(str, version_info))

chemfunc/plot_tsne.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,20 @@
1616
def plot_tsne(
1717
data_paths: list[Path],
1818
save_path: Path,
19-
method: Literal['t-SNE'] = 't-SNE',
2019
metric: Literal['jaccard', 'euclidean'] = 'jaccard',
2120
embedder: Literal['morgan', 'file'] = 'morgan',
2221
max_molecules: list[int] | None = None,
2322
colors: list[str] | None = None,
2423
smiles_columns: list[str] | None = None,
2524
data_names: list[str] | None = None,
2625
highlight_data_names: set[str] | None = None,
27-
display_data_names: set[str] | None = None
26+
display_data_names: set[str] | None = None,
27+
point_size: int = 1000
2828
) -> None:
2929
"""Runs a t-SNE on molecular fingerprints from one or more chemical libraries.
3030
3131
:param data_paths: Paths to CSV files containing SMILES.
32-
:param save_path: Path to a PDF file where the dimensionality reduction plot will be saved.
33-
:param method: Dimensionality reduction method.
32+
:param save_path: Path to a PDF file where the t-SNE plot will be saved.
3433
:param metric: Metric to use to compare embeddings.
3534
:param embedder: Embedding to use for the molecules.
3635
morgan: Computes Morgan fingerprint from the SMILES.
@@ -43,6 +42,7 @@ def plot_tsne(
4342
:param data_names: Names of the data files for labeling the plot.
4443
:param highlight_data_names: Names of the data files to highlight in the plot.
4544
:param display_data_names: The names of the data files to display in the plot. If None, all are displayed.
45+
:param point_size: The size of the points in the plot.
4646
"""
4747
# Validate max_molecules
4848
if max_molecules is None:
@@ -87,8 +87,9 @@ def plot_tsne(
8787

8888
# Load data and subsample SMILES
8989
smiles, slices, embeddings = [], [], []
90-
for data_path, smiles_column, data_name, max_mols in tqdm(zip(data_paths, smiles_columns, data_names, max_molecules),
91-
total=len(data_paths), desc='Loading data'):
90+
for data_path, smiles_column, data_name, max_mols in tqdm(
91+
zip(data_paths, smiles_columns, data_names, max_molecules),
92+
total=len(data_paths), desc='Loading data'):
9293
# Load data
9394
data = pd.read_csv(data_path)
9495
print(f'{data_name}: {len(data):,}')
@@ -118,15 +119,12 @@ def plot_tsne(
118119
else:
119120
raise ValueError(f'Embedder "{embedder}" is not supported.')
120121

121-
# Run dimensionality reduction
122-
if method == 't-SNE':
123-
reducer = TSNE(random_state=0, metric=metric, init='pca', n_jobs=-1, square_distances=True)
124-
else:
125-
raise ValueError(f'Dimensionality reduction method "{method}" is not supported.')
122+
# Run t-SNE
123+
tsne = TSNE(random_state=0, metric=metric, init='pca', n_jobs=-1)
126124

127-
print(f'Running {method}')
125+
print(f'Running t-SNE')
128126
start = time.time()
129-
X = reducer.fit_transform(embeddings)
127+
X = tsne.fit_transform(embeddings)
130128
print(f'time = {time.time() - start:.2f} seconds')
131129

132130
print('Plotting')
@@ -135,15 +133,15 @@ def plot_tsne(
135133

136134
plt.clf()
137135
plt.figure(figsize=(64, 48))
138-
plt.title(f'{method} using Morgan fingerprint with {metric.title()} similarity', fontsize=100)
136+
plt.title(f't-SNE using Morgan fingerprint with {metric.title()} similarity', fontsize=100)
139137

140138
tsne_data = {}
141139
for index, (slc, data_name) in enumerate(zip(slices, data_names)):
142140
if display_data_names is None or data_name in display_data_names:
143141
plt.scatter(
144142
X[slc, 0],
145143
X[slc, 1],
146-
s=1500 if data_name in highlight_data_names else 1000,
144+
s=1.5 * point_size if data_name in highlight_data_names else point_size,
147145
color=colors[index],
148146
label=data_name,
149147
marker='*' if data_name in highlight_data_names else '.'

0 commit comments

Comments
 (0)