
Commit 454a3b0

Author: Alexander Hillsley
Commit message: add more evaluation metrics
1 parent 05f168a commit 454a3b0

File tree
4 files changed: +867 −89 lines
Lines changed: 392 additions & 0 deletions
@@ -0,0 +1,392 @@
# %%
from tqdm import tqdm

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from ops_model.data.embeddings.utils import load_adata
def embedding_spread(adata, label, plot=True):
    """
    Calculates the mean cosine similarity from each embedding to the centroid of the class
    for all embeddings with the specified label in adata.obs['label_str'].

    Returns:
        mean_similarity: the average cosine similarity to the centroid
        std_similarity: the standard deviation of the cosine similarities
        fig: a histogram of the cosine similarities (None if plot=False)
    """

    # Filter observations with the specified label
    mask = adata.obs["label_str"] == label
    embeddings = torch.tensor(adata.X[mask]).cuda()

    if embeddings.shape[0] < 2:
        print(f"Not enough samples for label '{label}': {embeddings.shape[0]}")
        # Return three values so callers can always unpack (mean, std, fig)
        return None, None, None

    # Normalize embeddings for cosine similarity computation
    x = F.normalize(embeddings, dim=1)

    # Compute the centroid (mean of normalized embeddings, then re-normalize)
    centroid = x.mean(dim=0, keepdim=True)
    centroid = F.normalize(centroid, dim=1)

    # Compute cosine similarity from each embedding to the centroid
    cosine_similarities = (x * centroid).sum(dim=1).cpu()
    mean_similarity = cosine_similarities.mean().item()
    std_similarity = cosine_similarities.std().item()

    if plot:
        # Create histogram
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.hist(cosine_similarities.numpy(), bins=50, edgecolor="black", alpha=0.7)
        ax.set_xlabel("Cosine Similarity to Centroid")
        ax.set_ylabel("Frequency")
        ax.set_title(f"Cosine Similarity to Centroid for Label: {label}")
        ax.axvline(
            mean_similarity,
            color="red",
            linestyle="--",
            linewidth=2,
            label=f"Mean: {mean_similarity:.4f}",
        )
        ax.legend()
        plt.tight_layout()
        plt.close(fig)

        return mean_similarity, std_similarity, fig

    return mean_similarity, std_similarity, None
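A minimal usage sketch (an editor's addition, not part of the commit): embedding_spread on a synthetic AnnData where one label is tightly clustered and one is not. It assumes anndata and pandas are installed and a CUDA device is available (the function calls .cuda()); the labels "tight" and "loose" are made up for illustration.

import anndata as ad
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
base = rng.normal(size=(1, 64))
tight = base + 0.01 * rng.normal(size=(100, 64))  # small perturbations around one vector
loose = rng.normal(size=(100, 64))                # unrelated vectors
X = np.vstack([tight, loose]).astype(np.float32)
obs = pd.DataFrame(
    {"label_str": ["tight"] * 100 + ["loose"] * 100},
    index=[str(i) for i in range(200)],
)
toy = ad.AnnData(X=X, obs=obs)

mean_tight, std_tight, _ = embedding_spread(toy, "tight", plot=False)  # mean near 1
mean_loose, std_loose, _ = embedding_spread(toy, "loose", plot=False)  # lower, broader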
def embedding_spread_all_labels(adata, min_samples=2):
    """
    Calculates the embedding spread (mean cosine similarity to centroid) for all labels
    in the adata object.

    Args:
        adata: AnnData object with embeddings in .X
        min_samples: Minimum number of samples required to compute spread (default: 2)

    Returns:
        top_10_tightest: List of (label, mean_similarity, std_similarity) tuples for the 10 most compact labels
        top_10_diffuse: List of (label, mean_similarity, std_similarity) tuples for the 10 most diffuse labels
        sorted_results: List of (label, mean_similarity, std_similarity) tuples sorted by mean similarity
        fig: Matplotlib figure with histogram of mean similarities across all labels
    """

    # Get all unique labels
    unique_labels = adata.obs["label_str"].unique()

    results_dict = {}

    for label in tqdm(unique_labels):
        # Check if label has enough samples
        n_samples = (adata.obs["label_str"] == label).sum()

        if n_samples < min_samples:
            print(f"Skipping label '{label}': only {n_samples} samples")
            continue

        # Compute embedding spread for this label
        mean_sim, std_sim, _ = embedding_spread(adata, label, plot=False)

        if mean_sim is not None:
            results_dict[label] = (mean_sim, std_sim)

    # Sort by mean similarity (highest similarity = most compact clusters first)
    sorted_results = sorted(
        [(label, mean, std) for label, (mean, std) in results_dict.items()],
        key=lambda x: x[1],
        reverse=True,
    )

    # Create histogram of mean similarities
    mean_similarities = [mean for _, mean, _ in sorted_results]
    overall_mean = np.mean(mean_similarities)

    top_10_tightest = sorted_results[:10]
    top_10_diffuse = sorted_results[-10:]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(mean_similarities, bins=30, edgecolor="black", alpha=0.7)
    ax.set_xlabel("Mean Cosine Similarity to Centroid")
    ax.set_ylabel("Number of Labels")
    ax.set_title("Distribution of Embedding Spread Across Labels")
    ax.axvline(
        overall_mean,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Mean: {overall_mean:.4f}",
    )
    ax.legend()
    plt.tight_layout()
    plt.close(fig)

    return top_10_tightest, top_10_diffuse, sorted_results, fig
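The sorted spread results drop straight into a table for downstream inspection. A sketch (an editor's addition), assuming adata is an AnnData loaded as in the __main__ block below and pandas is installed; the output filename is arbitrary.

import pandas as pd

top_tight, top_diffuse, sorted_results, spread_fig = embedding_spread_all_labels(adata, min_samples=2)
spread_df = pd.DataFrame(sorted_results, columns=["label", "mean_cos_sim", "std_cos_sim"])
spread_df.to_csv("embedding_spread_by_label.csv", index=False)  # hypothetical output path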
def cosine_similarity_to_reference(adata, reference_label):
    """
    Calculates the cosine similarity from each label's embedding to a reference label's embedding.
    Assumes each label has exactly one embedding in the adata object.

    Args:
        adata: AnnData object with embeddings in .X (one embedding per label)
        reference_label: The reference label to compare against

    Returns:
        top_10_closest: List of (label, similarity) tuples for the 10 labels closest to the reference
        top_10_furthest: List of (label, similarity) tuples for the 10 labels furthest from the reference
        sorted_labels: List of (label, similarity) tuples sorted by similarity (most similar first)
        fig: Matplotlib figure with histogram of similarities
    """

    # Get reference embedding
    ref_mask = adata.obs["label_str"] == reference_label
    if ref_mask.sum() == 0:
        raise ValueError(f"Reference label '{reference_label}' not found in adata")
    if ref_mask.sum() > 1:
        raise ValueError(
            f"Reference label '{reference_label}' has multiple embeddings, expected 1"
        )

    ref_embedding = torch.tensor(adata.X[ref_mask]).cuda()
    ref_x = F.normalize(ref_embedding, dim=1)

    # Get all embeddings and labels
    all_embeddings = torch.tensor(adata.X).cuda()
    all_x = F.normalize(all_embeddings, dim=1)

    # Compute cosine similarity between reference and all embeddings
    cosine_similarities = (ref_x @ all_x.T).squeeze().cpu()

    # Create dictionary mapping labels to similarities
    similarities_dict = {}
    for idx, label in enumerate(adata.obs["label_str"]):
        similarities_dict[label] = cosine_similarities[idx].item()

    # Sort labels by similarity (highest first)
    sorted_labels = sorted(similarities_dict.items(), key=lambda x: x[1], reverse=True)

    # Calculate mean similarity
    mean_similarity = cosine_similarities.mean().item()

    # Skip index 0, which is the reference label itself (similarity ~1.0)
    top_10_closest = sorted_labels[1:11]
    top_10_furthest = sorted_labels[-10:]

    # Create histogram
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(cosine_similarities.numpy(), bins=50, edgecolor="black", alpha=0.7)
    ax.set_xlabel("Cosine Similarity")
    ax.set_ylabel("Frequency")
    ax.set_title(f"Cosine Similarity Distribution to Reference: {reference_label}")
    ax.axvline(
        mean_similarity,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Mean: {mean_similarity:.4f}",
    )
    ax.legend()
    plt.tight_layout()
    plt.close(fig)

    return top_10_closest, top_10_furthest, sorted_labels, fig
def mean_similarity(
    adata,
    n_samples=10_000_000,
    batch_size=10000,
):
    """
    Compute mean and std of pairwise cosine similarities by sampling random pairs,
    which is much faster than computing all pairs for large datasets.

    Args:
        adata: AnnData object with embeddings in .X
        n_samples: Number of random pairs to sample
        batch_size: Batch size for processing the sampled pairs

    Returns:
        mean_similarity: Mean cosine similarity across the sampled pairs
        std_similarity: Standard deviation of the sampled cosine similarities
        fig: Matplotlib figure with histogram of the sampled similarities
    """
    embeddings = torch.tensor(adata.X).cuda()
    x = F.normalize(embeddings, dim=1)
    n = x.shape[0]

    # Sampling approach: much faster for large datasets
    # Sample random pairs and compute their similarities
    n_samples = min(n_samples, n * (n - 1) // 2)  # Don't sample more than total pairs

    # Generate random pairs
    idx_i = torch.randint(0, n, (n_samples,), device="cuda")
    idx_j = torch.randint(0, n, (n_samples,), device="cuda")

    # Ensure i != j
    mask = idx_i == idx_j
    idx_j[mask] = (idx_j[mask] + 1) % n

    # Compute similarities for sampled pairs in batches
    similarities = []
    for start in tqdm(range(0, n_samples, batch_size)):
        end = min(start + batch_size, n_samples)
        batch_i = x[idx_i[start:end]]
        batch_j = x[idx_j[start:end]]
        sim = (batch_i * batch_j).sum(dim=1)
        similarities.append(sim.cpu())

    similarities = torch.cat(similarities)
    mean_similarity = similarities.mean().item()
    std_similarity = similarities.std().item()

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(similarities.numpy(), bins=100, edgecolor="black", alpha=0.7)
    ax.set_xlabel("Cosine Similarity")
    ax.set_ylabel("Number of Pairs")
    ax.set_title("Cosine Similarity Distribution Across All Pairs")
    ax.axvline(
        mean_similarity,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Mean: {mean_similarity:.4f}",
    )
    ax.legend()
    plt.tight_layout()
    plt.close(fig)

    return mean_similarity, std_similarity, fig
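Because mean_similarity samples pairs with replacement, it is an estimator; a brute-force computation over all unique pairs is a useful cross-check when n is small. A sketch of that check (an editor's addition): mean_similarity_exact and max_n are hypothetical names, and the same normalization convention as above is assumed.

def mean_similarity_exact(adata, max_n=2000):
    # Exact mean/std over all unique pairs; only feasible for small n,
    # since the similarity matrix is n x n.
    x = F.normalize(torch.tensor(adata.X[:max_n]).cuda(), dim=1)
    n = x.shape[0]
    sims = x @ x.T
    iu = torch.triu_indices(n, n, offset=1, device=sims.device)  # pairs with i < j
    pair_sims = sims[iu[0], iu[1]].cpu()
    return pair_sims.mean().item(), pair_sims.std().item()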
def mean_similarity_within_labels(
    adata,
    n_samples_per_label=10000,
    batch_size=10000,
):
    """
    Compute mean and std of pairwise cosine similarities for pairs within the same label.
    Only computes similarities between embeddings that share the same label in adata.obs['label_str'].

    Args:
        adata: AnnData object with embeddings in .X
        n_samples_per_label: Number of pairs to sample per label
        batch_size: Batch size for processing

    Returns:
        mean_similarity: Mean cosine similarity across all within-label pairs
        std_similarity: Standard deviation of the cosine similarities
        fig: Matplotlib figure with histogram of the within-label similarities
    """
    embeddings = torch.tensor(adata.X).cuda()
    x = F.normalize(embeddings, dim=1)

    # Get unique labels
    unique_labels = adata.obs["label_str"].unique()

    all_similarities = []

    for label in tqdm(unique_labels, desc="Processing labels"):
        # Get indices for this label
        label_mask = adata.obs["label_str"] == label
        label_indices = np.where(label_mask)[0]
        n_label = len(label_indices)

        # Skip labels with only one sample
        if n_label < 2:
            continue

        # Determine number of pairs to sample for this label
        n_possible_pairs = n_label * (n_label - 1) // 2
        n_samples = min(n_samples_per_label, n_possible_pairs)

        # Convert label indices to tensor
        label_indices_tensor = torch.tensor(label_indices, device="cuda")

        # Generate random pairs within this label
        idx_i = torch.randint(0, n_label, (n_samples,), device="cuda")
        idx_j = torch.randint(0, n_label, (n_samples,), device="cuda")

        # Ensure i != j
        mask = idx_i == idx_j
        idx_j[mask] = (idx_j[mask] + 1) % n_label

        # Map to actual indices in the full dataset
        actual_idx_i = label_indices_tensor[idx_i]
        actual_idx_j = label_indices_tensor[idx_j]

        # Compute similarities for sampled pairs in batches
        for start in range(0, n_samples, batch_size):
            end = min(start + batch_size, n_samples)
            batch_i = x[actual_idx_i[start:end]]
            batch_j = x[actual_idx_j[start:end]]
            sim = (batch_i * batch_j).sum(dim=1)
            all_similarities.append(sim.cpu())

    # Guard against the degenerate case where no label had >= 2 samples,
    # which would otherwise crash torch.cat with an opaque error
    if not all_similarities:
        raise ValueError("No label has at least two samples; nothing to compare")

    # Combine all similarities
    all_similarities = torch.cat(all_similarities)
    mean_similarity = all_similarities.mean().item()
    std_similarity = all_similarities.std().item()

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(all_similarities.numpy(), bins=100, edgecolor="black", alpha=0.7)
    ax.set_xlabel("Cosine Similarity")
    ax.set_ylabel("Number of Pairs")
    ax.set_title("Cosine Similarity Distribution Within Labels")
    ax.axvline(
        mean_similarity,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Mean: {mean_similarity:.4f}",
    )
    ax.legend()
    plt.tight_layout()
    plt.close(fig)

    return mean_similarity, std_similarity, fig
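The two summaries above are most informative side by side: the gap between the within-label mean and the across-pair mean is a single scalar for how much label structure the embedding carries. A sketch of that comparison (an editor's addition), assuming adata_cells is loaded as in the __main__ block below; the sample sizes are arbitrary.

within_mean, _, _ = mean_similarity_within_labels(adata_cells, n_samples_per_label=1_000)
across_mean, _, _ = mean_similarity(adata_cells, n_samples=100_000)
print(f"within-label mean: {within_mean:.4f}, across-pair mean: {across_mean:.4f}")
print(f"separation (within - across): {within_mean - across_mean:.4f}")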
if __name__ == "__main__":
    adata_path = "/hpc/projects/intracellular_dashboard/ops/ops0031_20250424/3-assembly/dynaclr_features"
    adata_cells, adata_guides, adata_genes = load_adata(adata_path)

    mean_similarity_all, std_similarity_all, across_labels_fig = mean_similarity(
        adata_cells, n_samples=1_000_000
    )
    mean_similarity_within, std_similarity_within, within_labels_fig = (
        mean_similarity_within_labels(adata_cells, n_samples_per_label=10_000)
    )

    top_10_closest, top_10_furthest, sorted_labels, fig = (
        cosine_similarity_to_reference(adata_genes, reference_label="NTC")
    )
    top_10_tightest, top_10_diffuse, spread_sorted_labels, spread_fig = (
        embedding_spread_all_labels(adata_cells, min_samples=2)
    )

    print("Mean Cosine Similarity (All Cells):", f"{mean_similarity_all:.4f}")
    print("Std Dev Cosine Similarity (All Cells):", f"{std_similarity_all:.4f}")
    print("\nTop 10 Closest Labels to NTC:")
    for label, sim in top_10_closest:
        print(f"  {label}: {sim:.4f}")
    print("\nTop 10 Furthest Labels from NTC:")
    for label, sim in top_10_furthest:
        print(f"  {label}: {sim:.4f}")
    print("\nTop 10 Tightest Embedding Spreads:")
    for label, mean_sim, std_sim in top_10_tightest:
        print(f"  {label}: Mean Similarity = {mean_sim:.4f}, Std Dev = {std_sim:.4f}")
    print("\nTop 10 Most Diffuse Embedding Spreads:")
    for label, mean_sim, std_sim in top_10_diffuse:
        print(f"  {label}: Mean Similarity = {mean_sim:.4f}, Std Dev = {std_sim:.4f}")
"""
386+
Notes
387+
- Cosine Similarity across labels should be a broad distribution centered around 0
388+
- Cosine Similarity within labels should be skewed towards 1, indicating that embeddings
389+
sharing the same label are more similar to each other
390+
391+
- TODO: how do we test the within / across labels with bulking?
392+
"""
