# %%
from tqdm import tqdm

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from ops_model.data.embeddings.utils import load_adata


def embedding_spread(adata, label, plot=True):
    """
    Calculates the mean cosine similarity from each embedding to the centroid of the class
    for all embeddings with the specified label in adata.obs['label_str'].

    Returns:
        - the mean cosine similarity to the centroid
        - the standard deviation of those similarities
        - a histogram figure of the similarities (None when plot=False)
    """

    # Filter observations with the specified label
    mask = adata.obs["label_str"] == label
    embeddings = torch.tensor(adata.X[mask]).cuda()

    if embeddings.shape[0] < 2:
        print(f"Not enough samples for label '{label}': {embeddings.shape[0]}")
        return None, None, None

    # Normalize embeddings for cosine similarity computation
    x = F.normalize(embeddings, dim=1)

    # Compute the centroid (mean of normalized embeddings, then re-normalize)
    centroid = x.mean(dim=0, keepdim=True)
    centroid = F.normalize(centroid, dim=1)

    # Compute cosine similarity from each embedding to the centroid
    cosine_similarities = (x * centroid).sum(dim=1).cpu()
    mean_similarity = cosine_similarities.mean().item()
    std_similarity = cosine_similarities.std().item()

    if plot:
        # Create histogram
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.hist(cosine_similarities.numpy(), bins=50, edgecolor="black", alpha=0.7)
        ax.set_xlabel("Cosine Similarity to Centroid")
        ax.set_ylabel("Frequency")
        ax.set_title(f"Cosine Similarity to Centroid for Label: {label}")
        ax.axvline(
            mean_similarity,
            color="red",
            linestyle="--",
            linewidth=2,
            label=f"Mean: {mean_similarity:.4f}",
        )
        ax.legend()
        plt.tight_layout()
        plt.close(fig)

        return mean_similarity, std_similarity, fig

    return mean_similarity, std_similarity, None


def embedding_spread_all_labels(adata, min_samples=2):
    """
    Calculates the embedding spread (mean cosine similarity to centroid) for all labels
    in the adata object.

    Args:
        adata: AnnData object with embeddings in .X
        min_samples: Minimum number of samples required to compute spread (default: 2)

    Returns:
        top_10_tightest: The 10 labels with the highest mean similarity (most compact)
        top_10_diffuse: The 10 labels with the lowest mean similarity (most diffuse)
        sorted_results: List of (label, mean_similarity, std_similarity) tuples sorted by mean similarity
        fig: Matplotlib figure with histogram of mean similarities across all labels
    """
    # Get all unique labels
    unique_labels = adata.obs["label_str"].unique()

    results_dict = {}

    for label in tqdm(unique_labels):
        # Check if label has enough samples
        n_samples = (adata.obs["label_str"] == label).sum()

        if n_samples < min_samples:
            print(f"Skipping label '{label}': only {n_samples} samples")
            continue

        # Compute embedding spread for this label
        mean_sim, std_sim, _ = embedding_spread(adata, label, plot=False)

        if mean_sim is not None:
            results_dict[label] = (mean_sim, std_sim)

    # Sort by mean similarity (highest similarity = most compact clusters first)
    sorted_results = sorted(
        [(label, mean, std) for label, (mean, std) in results_dict.items()],
        key=lambda x: x[1],
        reverse=True,
    )

    # Create histogram of mean similarities
    mean_similarities = [mean for _, mean, _ in sorted_results]
    overall_mean = np.mean(mean_similarities)

    top_10_tightest = sorted_results[:10]
    top_10_diffuse = sorted_results[-10:]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(mean_similarities, bins=30, edgecolor="black", alpha=0.7)
    ax.set_xlabel("Mean Cosine Similarity to Centroid")
    ax.set_ylabel("Number of Labels")
    ax.set_title("Distribution of Embedding Spread Across Labels")
    ax.axvline(
        overall_mean,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Mean: {overall_mean:.4f}",
    )
    ax.legend()
    plt.tight_layout()

    return top_10_tightest, top_10_diffuse, sorted_results, fig


def cosine_similarity_to_reference(adata, reference_label):
    """
    Calculates the cosine similarity from each label's embedding to a reference label's embedding.
    Assumes each label has exactly one embedding in the adata object.

    Args:
        adata: AnnData object with embeddings in .X (one embedding per label)
        reference_label: The reference label to compare against

    Returns:
        top_10_closest: The 10 labels most similar to the reference (excluding the reference itself)
        top_10_furthest: The 10 labels least similar to the reference
        sorted_labels: List of (label, similarity) tuples sorted by similarity (most similar first)
        fig: Matplotlib figure with histogram of similarities
    """

    # Get reference embedding
    ref_mask = adata.obs["label_str"] == reference_label
    if ref_mask.sum() == 0:
        raise ValueError(f"Reference label '{reference_label}' not found in adata")
    if ref_mask.sum() > 1:
        raise ValueError(
            f"Reference label '{reference_label}' has multiple embeddings, expected 1"
        )

    ref_embedding = torch.tensor(adata.X[ref_mask]).cuda()
    ref_x = F.normalize(ref_embedding, dim=1)

    # Get all embeddings and labels
    all_embeddings = torch.tensor(adata.X).cuda()
    all_x = F.normalize(all_embeddings, dim=1)

    # Compute cosine similarity between reference and all embeddings
    cosine_similarities = (ref_x @ all_x.T).squeeze().cpu()

    # Create dictionary mapping labels to similarities
    similarities_dict = {}
    for idx, label in enumerate(adata.obs["label_str"]):
        similarities_dict[label] = cosine_similarities[idx].item()

    # Sort labels by similarity (highest first)
    sorted_labels = sorted(similarities_dict.items(), key=lambda x: x[1], reverse=True)

    # Calculate mean similarity
    mean_similarity = cosine_similarities.mean().item()

    # Index 0 is the reference label itself (similarity 1.0), so skip it
    top_10_closest = sorted_labels[1:11]
    top_10_furthest = sorted_labels[-10:]

    # Create histogram
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(cosine_similarities.numpy(), bins=50, edgecolor="black", alpha=0.7)
    ax.set_xlabel("Cosine Similarity")
    ax.set_ylabel("Frequency")
    ax.set_title(f"Cosine Similarity Distribution to Reference: {reference_label}")
    ax.axvline(
        mean_similarity,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Mean: {mean_similarity:.4f}",
    )
    ax.legend()
    plt.tight_layout()

    return top_10_closest, top_10_furthest, sorted_labels, fig


def mean_similarity(
    adata,
    n_samples=10_000_000,
    batch_size=10000,
):
    """
    Estimate the mean and std of pairwise cosine similarities by sampling random pairs.

    Args:
        adata: AnnData object with embeddings in .X
        n_samples: Number of random pairs to sample (capped at the total number of pairs)
        batch_size: Batch size for processing the sampled pairs

    Returns:
        mean_similarity: Mean cosine similarity across the sampled pairs
        std_similarity: Standard deviation of the sampled similarities
        fig: Matplotlib figure with histogram of the sampled similarities
    """
    embeddings = torch.tensor(adata.X).cuda()
    x = F.normalize(embeddings, dim=1)
    n = x.shape[0]

    # Sampling approach: much faster for large datasets
    # Sample random pairs and compute their similarities
    n_samples = min(n_samples, n * (n - 1) // 2)  # Don't sample more than total pairs

    # Generate random pairs
    idx_i = torch.randint(0, n, (n_samples,), device="cuda")
    idx_j = torch.randint(0, n, (n_samples,), device="cuda")

    # Ensure i != j
    mask = idx_i == idx_j
    idx_j[mask] = (idx_j[mask] + 1) % n

    # Compute similarities for sampled pairs in batches
    similarities = []
    for start in tqdm(range(0, n_samples, batch_size)):
        end = min(start + batch_size, n_samples)
        batch_i = x[idx_i[start:end]]
        batch_j = x[idx_j[start:end]]
        sim = (batch_i * batch_j).sum(dim=1)
        similarities.append(sim.cpu())

    similarities = torch.cat(similarities)
    mean_similarity = similarities.mean().item()
    std_similarity = similarities.std().item()

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(similarities.numpy(), bins=100, edgecolor="black", alpha=0.7)
    ax.set_xlabel("Cosine Similarity")
    ax.set_ylabel("Number of Pairs")
    ax.set_title("Pairwise Cosine Similarity Across All Embeddings")
    ax.axvline(
        mean_similarity,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Mean: {mean_similarity:.4f}",
    )
    ax.legend()
    plt.tight_layout()
    plt.close(fig)

    return mean_similarity, std_similarity, fig


def mean_similarity_within_labels(
    adata,
    n_samples_per_label=10000,
    batch_size=10000,
):
    """
    Compute mean and std of pairwise cosine similarities for pairs within the same label.
    Only computes similarities between embeddings that share the same label in adata.obs['label_str'].

    Args:
        adata: AnnData object with embeddings in .X
        n_samples_per_label: Number of pairs to sample per label
        batch_size: Batch size for processing

    Returns:
        mean_similarity: Mean cosine similarity across all within-label pairs
        std_similarity: Standard deviation of cosine similarities
        fig: Matplotlib figure with histogram of the within-label similarities
    """
    embeddings = torch.tensor(adata.X).cuda()
    x = F.normalize(embeddings, dim=1)

    # Get unique labels
    unique_labels = adata.obs["label_str"].unique()

    all_similarities = []

    for label in tqdm(unique_labels, desc="Processing labels"):
        # Get indices for this label
        label_mask = adata.obs["label_str"] == label
        label_indices = np.where(label_mask)[0]
        n_label = len(label_indices)

        # Skip labels with only one sample
        if n_label < 2:
            continue

        # Determine number of pairs to sample for this label
        n_possible_pairs = n_label * (n_label - 1) // 2
        n_samples = min(n_samples_per_label, n_possible_pairs)

        # Convert label indices to tensor
        label_indices_tensor = torch.tensor(label_indices, device="cuda")

        # Generate random pairs within this label
        idx_i = torch.randint(0, n_label, (n_samples,), device="cuda")
        idx_j = torch.randint(0, n_label, (n_samples,), device="cuda")

        # Ensure i != j
        mask = idx_i == idx_j
        idx_j[mask] = (idx_j[mask] + 1) % n_label

        # Map to actual indices in the full dataset
        actual_idx_i = label_indices_tensor[idx_i]
        actual_idx_j = label_indices_tensor[idx_j]

        # Compute similarities for sampled pairs in batches
        for start in range(0, n_samples, batch_size):
            end = min(start + batch_size, n_samples)
            batch_i = x[actual_idx_i[start:end]]
            batch_j = x[actual_idx_j[start:end]]
            sim = (batch_i * batch_j).sum(dim=1)
            all_similarities.append(sim.cpu())

    # Combine all similarities
    all_similarities = torch.cat(all_similarities)
    mean_similarity = all_similarities.mean().item()
    std_similarity = all_similarities.std().item()

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(all_similarities.numpy(), bins=100, edgecolor="black", alpha=0.7)
    ax.set_xlabel("Cosine Similarity")
    ax.set_ylabel("Number of Pairs")
    ax.set_title("Cosine Similarity Distribution Within Labels")
    ax.axvline(
        mean_similarity,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Mean: {mean_similarity:.4f}",
    )
    ax.legend()
    plt.tight_layout()
    plt.close(fig)

    return mean_similarity, std_similarity, fig


if __name__ == "__main__":
    adata_path = "/hpc/projects/intracellular_dashboard/ops/ops0031_20250424/3-assembly/dynaclr_features"
    adata_cells, adata_guides, adata_genes = load_adata(adata_path)

    mean_similarity_all, std_similarity_all, across_labels_fig = mean_similarity(
        adata_cells, n_samples=1_000_000
    )
    mean_similarity_within, std_similarity_within, within_labels_fig = (
        mean_similarity_within_labels(adata_cells, n_samples_per_label=10_000)
    )

    top_10_closest, top_10_furthest, sorted_labels, fig = (
        cosine_similarity_to_reference(adata_genes, reference_label="NTC")
    )
    top_10_tightest, top_10_diffuse, spread_sorted_labels, spread_fig = (
        embedding_spread_all_labels(adata_cells, min_samples=2)
    )
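
    # Illustrative extra step (an added sketch, not part of the original analysis):
    # embedding_spread can also be called directly for a single label. "NTC" below
    # is a placeholder; any value present in adata_cells.obs["label_str"] works.
    # ntc_mean, ntc_std, ntc_fig = embedding_spread(adata_cells, label="NTC")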

    print("Mean Cosine Similarity (All Cells):", f"{mean_similarity_all:.4f}")
    print("Std Dev Cosine Similarity (All Cells):", f"{std_similarity_all:.4f}")
    print("\nTop 10 Closest Labels to NTC:")
    for label, sim in top_10_closest:
        print(f"  {label}: {sim:.4f}")
    print("\nTop 10 Furthest Labels from NTC:")
    for label, sim in top_10_furthest:
        print(f"  {label}: {sim:.4f}")
    print("\nTop 10 Tightest Embedding Spreads:")
    for label, mean_sim, std_sim in top_10_tightest:
        print(f"  {label}: Mean Similarity = {mean_sim:.4f}, Std Dev = {std_sim:.4f}")
    print("\nTop 10 Most Diffuse Embedding Spreads:")
    for label, mean_sim, std_sim in top_10_diffuse:
        print(f"  {label}: Mean Similarity = {mean_sim:.4f}, Std Dev = {std_sim:.4f}")


"""
Notes
    - Cosine Similarity across labels should be a broad distribution centered around 0
    - Cosine Similarity within labels should be skewed towards 1, indicating that embeddings
      sharing the same label are more similar to each other

    - TODO: how do we test the within / across labels with bulking?
"""
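

# %%
# Illustrative sanity check for the notes above: a minimal sketch on synthetic data,
# not part of the original analysis. The dimensionality, cluster count, and noise
# scale below are arbitrary assumptions. With clustered unit vectors, within-label
# pairs should skew towards cosine similarity 1 while across-label pairs should
# centre near 0, mirroring the expectations stated in the notes.
def _synthetic_similarity_check(
    n_labels=20, n_per_label=200, dim=128, noise=0.03, n_pairs=10_000
):
    # One random unit-norm centre per label, plus small isotropic noise per point
    centers = F.normalize(torch.randn(n_labels, dim), dim=1)
    labels = torch.arange(n_labels).repeat_interleave(n_per_label)
    points = F.normalize(centers[labels] + noise * torch.randn(len(labels), dim), dim=1)

    # Sample random pairs (dropping i == j) and split into within- and across-label sets
    i = torch.randint(0, len(labels), (n_pairs,))
    j = torch.randint(0, len(labels), (n_pairs,))
    keep = i != j
    sims = (points[i[keep]] * points[j[keep]]).sum(dim=1)
    same = labels[i[keep]] == labels[j[keep]]

    print(f"within-label mean similarity: {sims[same].mean():.3f}")
    print(f"across-label mean similarity: {sims[~same].mean():.3f}")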