
Commit 1815846

Author: Alexander Hillsley
Merge branch 'features'
2 parents: 6702044 + 4a4bddf

10 files changed: +1868 additions, -313 deletions

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
from tqdm import tqdm
import yaml

import torch
import numpy as np
import anndata as ad
import torch.nn.functional as F

from ops_model.data.paths import OpsPaths


def mean_similarity(
    adata,
    n_samples=10_000_000,
    batch_size=10000,
):
"""
18+
Compute mean and std of pairwise cosine similarities.
19+
20+
Args:
21+
adata: AnnData object with embeddings in .X
22+
use_sampling: If True, sample random pairs instead of computing all pairs
23+
n_samples: Number of pairs to sample (only used if use_sampling=True)
24+
batch_size: Batch size for processing (only used if use_sampling=False)
25+
"""
    embeddings = torch.tensor(adata.X).cuda()
    x = F.normalize(embeddings, dim=1)
    n = x.shape[0]

    # Sampling approach: much faster for large datasets
    # Sample random pairs and compute their similarities
    n_samples = min(n_samples, n * (n - 1) // 2)  # Don't sample more than total pairs

    # Generate random pairs
    idx_i = torch.randint(0, n, (n_samples,), device="cuda")
    idx_j = torch.randint(0, n, (n_samples,), device="cuda")

    # Ensure i != j
    mask = idx_i == idx_j
    idx_j[mask] = (idx_j[mask] + 1) % n

    # Compute similarities for sampled pairs in batches
    similarities = []
    for start in tqdm(range(0, n_samples, batch_size)):
        end = min(start + batch_size, n_samples)
        batch_i = x[idx_i[start:end]]
        batch_j = x[idx_j[start:end]]
        sim = (batch_i * batch_j).sum(dim=1)
        similarities.append(sim.cpu())

    similarities = torch.cat(similarities)
    mean_sim = similarities.mean().item()
    std_sim = similarities.std().item()

    return mean_sim, std_sim


def alignment_and_uniformity(adata, n_uniformity_samples=1_000_000, batch_size=10000):
    """
    Compute alignment and uniformity metrics for embeddings.

    Code adapted from:
        Wang, Tongzhou and Isola, Phillip. "Understanding Contrastive
        Representation Learning through Alignment and Uniformity on the
        Hypersphere." International Conference on Machine Learning,
        PMLR, pp. 9929-9939, 2020.

    Args:
        adata: AnnData object with embeddings in .X and integer labels in
            .obs["label_int"]
        n_uniformity_samples: Number of random pairs to sample for uniformity
        batch_size: Batch size for processing
    """
    gene_int_list = adata.obs["label_int"].unique().tolist()
    x = []  # positive pair i
    y = []  # positive pair j
    for i in tqdm(gene_int_list):
        single_gene_embs = adata.X[adata.obs["label_int"] == i]
        x += [single_gene_embs[j] for j in range(len(single_gene_embs) - 1)]
        y += [
            single_gene_embs[z]
            for z in np.random.permutation(np.arange(len(single_gene_embs) - 1))
        ]

    x = torch.tensor(np.asarray(x)).cuda()
    y = torch.tensor(np.asarray(y)).cuda()

    # Compute alignment (all positive pairs)
    alignment = (x - y).norm(p=2, dim=1).pow(2).mean().item()

    # Compute uniformity using sampling to avoid OOM
    n = x.shape[0]
    n_samples = min(n_uniformity_samples, n * (n - 1) // 2)

    # Sample random pairs for uniformity
    idx_i = torch.randint(0, n, (n_samples,), device="cuda")
    idx_j = torch.randint(0, n, (n_samples,), device="cuda")

    # Ensure i != j
    mask = idx_i == idx_j
    idx_j[mask] = (idx_j[mask] + 1) % n

    # Compute pairwise distances for sampled pairs in batches
    uniformity_vals = []
    for start in tqdm(range(0, n_samples, batch_size)):
        end = min(start + batch_size, n_samples)
        batch_i = x[idx_i[start:end]]
        batch_j = x[idx_j[start:end]]
        # Compute squared L2 distance
        dist_sq = (batch_i - batch_j).norm(p=2, dim=1).pow(2)
        uniformity_vals.append(dist_sq.cpu())

    uniformity_vals = torch.cat(uniformity_vals)
    uniformity = uniformity_vals.mul(-2).exp().mean().log().item()

    return alignment, uniformity


def compute_embedding_metrics(experiment):
    path = OpsPaths(experiment).cell_profiler_out
    save_dir = path.parent / "anndata_objects"
    assert save_dir.exists(), f"Anndata objects directory does not exist: {save_dir}"
    checkpoint_path = save_dir / "features_processed.h5ad"
    adata = ad.read_h5ad(checkpoint_path)
    assert "X_umap" in adata.obsm, "UMAP embeddings not found in AnnData object."
    plots_dir = OpsPaths(experiment).embedding_plot_dir
    plots_dir.mkdir(parents=True, exist_ok=True)

    mean_sim, std_sim = mean_similarity(adata, n_samples=10_000)

    # alignment, uniformity = alignment_and_uniformity(adata, n_uniformity_samples=30_000, batch_size=10_000)

    results = {
        "mean_cosine_similarity": mean_sim,
        "std_cosine_similarity": std_sim,
        # "alignment": alignment,
        # "uniformity": uniformity,
    }
    with open(plots_dir / "embedding_metrics.yaml", "w") as f:
        yaml.dump(results, f)

    return
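
For reference, alignment_and_uniformity above follows the two metrics defined in the Wang & Isola (2020) paper cited in its docstring, with the standard exponents (alpha = 2 for alignment, t = 2 for uniformity):

    alignment  = mean over positive pairs (x, y) of ||x - y||^2
    uniformity = log( mean over sampled pairs (x, y) of exp(-2 * ||x - y||^2) )

These correspond to the (x - y).norm(p=2, dim=1).pow(2) and uniformity_vals.mul(-2).exp().mean().log() expressions in the code; the one deviation from the paper is that uniformity is estimated from a random sample of pairs rather than all pairs, to avoid running out of memory.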

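A minimal usage sketch for the new metrics module (the import path and experiment name below are assumptions for illustration, not taken from this diff):

# Hypothetical usage; the module path and experiment name are illustrative assumptions.
from ops_model.metrics.embedding_metrics import compute_embedding_metrics  # assumed path

# Reads anndata_objects/features_processed.h5ad for the experiment and writes
# embedding_metrics.yaml (mean/std pairwise cosine similarity) to the embedding plot dir.
compute_embedding_metrics("example_experiment")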