diff --git a/chemCPA/model.py b/chemCPA/model.py index d1b3751..0e400b1 100644 --- a/chemCPA/model.py +++ b/chemCPA/model.py @@ -384,76 +384,73 @@ def set_hparams_(self, seed, hparams): self.hparams.update(hparams) return self.hparams + def _lookup_drug_embeddings(self, drugs_idx): + if drugs_idx.ndim == 1: + drugs_idx = drugs_idx.unsqueeze(1) + return self.drug_embeddings.weight[drugs_idx] + + def _transform_drug_embeddings(self, gathered_embeddings, dosages, drugs_idx): + batch_size, combo_drugs, _ = gathered_embeddings.shape + + if self.doser_type == "mlp": + scaled_dosages_list = [] + for i in range(combo_drugs): + dose_i = dosages[:, i].unsqueeze(1) + scaled_i = [] + for b in range(batch_size): + d_idx = drugs_idx[b, i].item() + scaled_val = self.dosers[d_idx](dose_i[b]).sigmoid() + scaled_i.append(scaled_val) + scaled_i = torch.cat(scaled_i, dim=0).unsqueeze(1) + scaled_dosages_list.append(scaled_i) + scaled_dosages = torch.cat(scaled_dosages_list, dim=1) + elif self.doser_type == "amortized": + scaled_list = [] + for i in range(combo_drugs): + emb_i = gathered_embeddings[:, i, :] + dose_i = dosages[:, i].unsqueeze(-1) + cat_i = torch.cat([emb_i, dose_i], dim=1) + scaled_i = self.dosers(cat_i).sigmoid() + scaled_list.append(scaled_i) + scaled_dosages = torch.stack(scaled_list, dim=1).squeeze(-1) + elif self.doser_type in ("sigm", "logsigm"): + scaled_list = [] + for i in range(combo_drugs): + dose_i = dosages[:, i] + scaled_val = self.dosers(dose_i, drugs_idx[:, i]).unsqueeze(1) + scaled_list.append(scaled_val) + scaled_dosages = torch.cat(scaled_list, dim=1) + else: + scaled_dosages = dosages + + if not self.enable_cpa_mode and self.drug_embedding_encoder is not None: + transformed_list = [] + for i in range(combo_drugs): + emb_i = self.drug_embedding_encoder(gathered_embeddings[:, i, :]) + transformed_list.append(emb_i.unsqueeze(1)) + transformed = torch.cat(transformed_list, dim=1) + else: + transformed = gathered_embeddings + + scaled_dosages_expanded = 
scaled_dosages.unsqueeze(-1) + scaled_embeddings = transformed * scaled_dosages_expanded + combo_embedding = scaled_embeddings.sum(dim=1) + return combo_embedding + def compute_drug_embeddings_(self, drugs=None, drugs_idx=None, dosages=None): """ - Compute sum of drug embeddings, each scaled by a dose-response curve. + Computes the drug embedding. + For a drugs_idx input, it performs a lookup then applies doser scaling and transformation. """ - assert (drugs is not None) or (drugs_idx is not None and dosages is not None), ( - "Either `drugs` or (`drugs_idx` and `dosages`) must be provided." - ) + if self.num_drugs <= 0: + return None drugs, drugs_idx, dosages = _move_inputs(drugs, drugs_idx, dosages, device=self.device) - all_embeddings = self.drug_embeddings.weight # shape [num_drugs, embedding_dim] - if drugs_idx is not None: - # Ensure 2D - if drugs_idx.ndim == 1: - drugs_idx = drugs_idx.unsqueeze(1) - dosages = dosages.unsqueeze(1) - batch_size, combo_drugs = drugs_idx.shape - - gathered_embeddings = all_embeddings[drugs_idx] # [batch_size, combo_drugs, emb_dim] - - # scaled dosages - if self.doser_type == "mlp": - scaled_dosages_list = [] - for i in range(combo_drugs): - idx_i = drugs_idx[:, i] - dose_i = dosages[:, i].unsqueeze(1) - scaled_i = [] - for b in range(batch_size): - d_idx = idx_i[b].item() - scaled_i.append(self.dosers[d_idx](dose_i[b]).sigmoid()) - scaled_i = torch.cat(scaled_i, dim=0) - scaled_dosages_list.append(scaled_i.unsqueeze(1)) - scaled_dosages = torch.cat(scaled_dosages_list, dim=1) - - elif self.doser_type == "amortized": - scaled_list = [] - for i in range(combo_drugs): - emb_i = gathered_embeddings[:, i, :] - dose_i = dosages[:, i].unsqueeze(-1) - cat_i = torch.cat([emb_i, dose_i], dim=1) - scaled_i = self.dosers(cat_i).sigmoid() - scaled_list.append(scaled_i) - scaled_dosages = torch.stack(scaled_list, dim=1).squeeze(-1) - - elif self.doser_type in ("sigm", "logsigm"): - scaled_list = [] - for i in range(combo_drugs): - dose_i = 
# NOTE(review): reconstructed from a line-mangled unified diff of
# chemCPA/model.py (methods of the model class, `self` is the model instance).
# The refactored `compute_drug_embeddings_` one-hot branch and the
# `early_stopping` body are partially hidden inside elided diff-hunk context
# in this view and are therefore not reproduced here.

def compute_covariate_embeddings_(self, covariates):
    """Embed one-hot covariates into latent space.

    Args:
        covariates: list of one-hot tensors, one per covariate type, each of
            shape [batch, n_categories]; may be None.

    Returns:
        List of [batch, latent_dim] embedding tensors; empty when there are
        no covariates (or `covariates` is None).
    """
    if covariates is None or self.num_covariates[0] == 0:
        return []
    cov_embeddings = []
    for i, emb_cov in enumerate(self.covariates_embeddings):
        # `.to` is a no-op when the module is already on the device.
        emb_cov = emb_cov.to(self.device)
        cov_idx = covariates[i].argmax(dim=1)
        cov_embeddings.append(emb_cov(cov_idx))
    return cov_embeddings


def predict(self, genes, drugs=None, drugs_idx=None, dosages=None, covariates=None, return_latent_basal=False):
    """Predict gene expression under drug / covariate perturbation.

    Thin orchestrator: moves inputs to the model device, embeds drugs and
    covariates, then delegates to `compute_prediction`.
    """
    genes, drugs, drugs_idx, dosages, covariates = _move_inputs(
        genes, drugs, drugs_idx, dosages, covariates, device=self.device
    )
    drug_embedding = self.compute_drug_embeddings_(drugs=drugs, drugs_idx=drugs_idx, dosages=dosages)
    cov_embeddings = self.compute_covariate_embeddings_(covariates)
    return self.compute_prediction(genes, drug_embedding, cov_embeddings, return_latent_basal)


def compute_prediction(self, genes, drug_embedding=None, cov_embeddings=None, return_latent_basal=False):
    """Decode the treated latent state into gene reconstructions.

    Args:
        genes: [batch, n_genes] expression input.
        drug_embedding: optional [batch, latent_dim] drug-combination
            embedding added to the basal latent.
        cov_embeddings: optional list of [batch, latent_dim] covariate
            embeddings, each added to the basal latent.
        return_latent_basal: also return intermediate latents.

    Returns:
        `[mean | softplus(var)]` concatenated to [batch, 2 * n_genes], plus a
        cell-drug embedding; with `return_latent_basal`, additionally the
        tuple (basal latent, drug embedding, treated latent).
    """
    latent = self.encoder(genes)
    latent_treated = latent
    if cov_embeddings:
        for cov_emb in cov_embeddings:
            latent_treated = latent_treated + cov_emb
    if drug_embedding is not None:
        latent_treated = latent_treated + drug_embedding

    gene_reconstructions = self.decoder(latent_treated)
    # Decoder emits mean and (pre-activation) variance side by side;
    # softplus keeps the variance positive.
    dim = gene_reconstructions.size(1) // 2
    mean = gene_reconstructions[:, :dim]
    var = F.softplus(gene_reconstructions[:, dim:])
    normalized_reconstructions = torch.cat([mean, var], dim=1)

    # NOTE(review): only the LAST covariate embedding is concatenated here,
    # mirroring the pre-refactor code that reused the final loop variable —
    # confirm this is intended rather than a sum over all covariates.
    if cov_embeddings and drug_embedding is not None:
        cell_drug_embedding = torch.cat([cov_embeddings[-1], drug_embedding], dim=1)
    elif drug_embedding is not None:
        cell_drug_embedding = drug_embedding
    else:
        cell_drug_embedding = torch.zeros_like(latent)

    if return_latent_basal:
        return normalized_reconstructions, cell_drug_embedding, (latent, drug_embedding, latent_treated)
    return normalized_reconstructions, cell_drug_embedding
- """ + def update(self, genes, drugs=None, drugs_idx=None, dosages=None, degs=None, covariates=None): assert (drugs is not None) or (drugs_idx is not None and dosages is not None) - - # ---- Forward pass (with debugging) ---- gene_reconstructions, cell_drug_embedding, (latent_basal, drug_embedding, latent_treated) = self.predict( genes=genes, drugs=drugs, @@ -559,8 +536,6 @@ def update( dim = gene_reconstructions.size(1) // 2 mean = gene_reconstructions[:, :dim] var = gene_reconstructions[:, dim:] - - # Debug check for NaNs if torch.isnan(mean).any() or torch.isnan(var).any(): print( f"NaN detected in mean/var:\n" @@ -570,17 +545,10 @@ def update( f" mean[:5] = {mean[:5]}\n" f" var[:5] = {var[:5]}" ) - - # ---- Reconstruction loss ---- reconstruction_loss = self.loss_autoencoder(input=mean, target=genes, var=var) - - # ---- Drug adversary loss (if used) ---- adversary_drugs_loss = torch.tensor([0.0], device=self.device) if self.num_drugs > 0: adversary_drugs_predictions = self.adversary_drugs(latent_basal) - # ...BCEWithLogitsLoss if multi-label... 
- - # ---- Covariates adversary loss (if used) ---- adversary_covariates_loss = torch.tensor([0.0], device=self.device) if self.num_covariates[0] > 0: adversary_covariate_predictions = [] @@ -592,9 +560,7 @@ def update( pred, covariates[i].argmax(1), ) - self.iteration += 1 - return { "loss_reconstruction": reconstruction_loss.item(), "loss_adv_drugs": adversary_drugs_loss.item(), @@ -605,3 +571,4 @@ def update( def defaults(cls): return cls.set_hparams_(cls, 0, "") + diff --git a/chemCPA/predict.py b/chemCPA/predict.py new file mode 100644 index 0000000..7dab225 --- /dev/null +++ b/chemCPA/predict.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python +import os +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import re +import torch +import scanpy as sc +import numpy as np +import pandas as pd +from lightning_module import ChemCPA +from embeddings.rdkit.embedding_rdkit import embed_and_save_embeddings +from tqdm import tqdm +import random + +# CONSTANTS +BATCH_SIZE = 4096 +DATASET_FILE = "/home/user/test/chemCPA/project_folder/datasets/lincs_full_smiles_sciplex_genes.h5ad" +CKPT_RELATIVE_PATH = Path("training_output", "training_output", "lincs_pretrain", "epoch=199-step=819900.ckpt") +SMILES_KEY = "canonical_smiles" +EMBEDDINGS_OUTPUT_PATH = "/tmp/10_random_rdkit_embeddings.parquet" +EMBEDDING_THRESHOLD = 0.01 +SKIP_VARIANCE_FILTER = True +DROP_COLS = ['latent_90', 'latent_103', 'latent_152', 'latent_164', + 'latent_176', 'latent_187', 'latent_196'] + +# Cache file paths +CONTROL_CACHE_PATH = "cached_control_genes.pt" +DRUG_EMBEDDINGS_CACHE_PATH = "cached_random_drug_embeddings.pt" + +def remove_non_alphanumeric(s): + return re.sub(r'[^a-zA-Z0-9]', '', s) + +def load_control_genes(file_path, device): + if os.path.exists(CONTROL_CACHE_PATH): + genes_tensor = torch.load(CONTROL_CACHE_PATH, map_location=device) + print("Loading cached control genes from", CONTROL_CACHE_PATH) + print("Control genes shape:", 
def load_control_genes(file_path, device):
    """Load DMSO control-cell expression as a float32 tensor, with a disk cache.

    Args:
        file_path: path to the AnnData (.h5ad) file.
        device: torch device for the returned tensor.

    Returns:
        Tensor [n_control_cells, n_genes].
    """
    if os.path.exists(CONTROL_CACHE_PATH):
        # NOTE(review): torch.load unpickles arbitrary objects; if this cache
        # file can be tampered with, pass weights_only=True (PyTorch >= 2.0).
        genes_tensor = torch.load(CONTROL_CACHE_PATH, map_location=device)
        print("Loading cached control genes from", CONTROL_CACHE_PATH)
        print("Control genes shape:", genes_tensor.shape)
        return genes_tensor
    import scanpy as sc  # lazy: heavy optional dependency
    adata = sc.read(file_path)
    if 'condition' not in adata.obs.columns:
        # Fall back to deriving the condition label from pert_iname.
        if 'pert_iname' in adata.obs.columns:
            adata.obs['condition'] = adata.obs['pert_iname'].apply(remove_non_alphanumeric)
        else:
            raise KeyError("Neither 'condition' nor 'pert_iname' found in adata.obs")
    control_adata = adata[adata.obs['condition'] == 'DMSO']
    # Sparse matrices must be densified before tensor conversion.
    if hasattr(control_adata.X, "toarray"):
        genes_tensor = torch.tensor(control_adata.X.toarray(), dtype=torch.float32, device=device)
    else:
        genes_tensor = torch.tensor(control_adata.X, dtype=torch.float32, device=device)
    torch.save(genes_tensor, CONTROL_CACHE_PATH)
    print("Saved cached control genes to", CONTROL_CACHE_PATH)
    print("Control genes shape:", genes_tensor.shape)
    return genes_tensor


def load_random_drug_embeddings(device):
    """Select 10 random SMILES, embed them with RDKit, and cache the tensor.

    Returns:
        Tensor [10, emb_dim] of RDKit embeddings on `device`.
    """
    if os.path.exists(DRUG_EMBEDDINGS_CACHE_PATH):
        drug_tensor = torch.load(DRUG_EMBEDDINGS_CACHE_PATH, map_location=device)
        print("Loading cached random drug embeddings from", DRUG_EMBEDDINGS_CACHE_PATH)
        print("Drug embeddings shape:", drug_tensor.shape)
        return drug_tensor
    import scanpy as sc  # lazy: heavy optional dependency
    from embeddings.rdkit.embedding_rdkit import embed_and_save_embeddings  # lazy project import
    adata = sc.read(DATASET_FILE)
    if SMILES_KEY not in adata.obs.columns:
        raise KeyError(f"SMILES key '{SMILES_KEY}' not found in dataset!")
    all_smiles = adata.obs[SMILES_KEY].dropna().tolist()
    if len(all_smiles) < 10:
        raise ValueError("Dataset contains fewer than 10 SMILES strings!")
    # NOTE(review): np.random.choice is unseeded — the selection differs on
    # every run until the cache file exists; seed it if reproducibility matters.
    random_smiles = list(np.random.choice(all_smiles, size=10, replace=False))
    print("Selected 10 random SMILES:")
    for s in random_smiles:
        print(s)
    embed_and_save_embeddings(
        random_smiles,
        threshold=EMBEDDING_THRESHOLD,
        embedding_path=EMBEDDINGS_OUTPUT_PATH,
        skip_variance_filter=SKIP_VARIANCE_FILTER
    )
    embeddings_df = pd.read_parquet(EMBEDDINGS_OUTPUT_PATH)
    present_drop_cols = [col for col in DROP_COLS if col in embeddings_df.columns]
    if present_drop_cols:
        embeddings_df = embeddings_df.drop(columns=present_drop_cols)
    print("Computed RDKit embeddings (after manual dropping):")
    print(embeddings_df)
    print("Final shape:", embeddings_df.shape)
    drug_tensor = torch.tensor(embeddings_df.values, dtype=torch.float32, device=device)
    torch.save(drug_tensor, DRUG_EMBEDDINGS_CACHE_PATH)
    print("Saved cached drug embeddings to", DRUG_EMBEDDINGS_CACHE_PATH)
    print("Drug embeddings shape:", drug_tensor.shape)
    return drug_tensor


def batched_predict_flat(model, genes, drug_embedding, cov_tensor, batch_size):
    """Run model.model.compute_prediction over row-aligned flat inputs in batches.

    Args:
        model: object exposing `model.compute_prediction(genes, drug, [cov])`.
        genes, drug_embedding, cov_tensor: row-aligned 2-D tensors.
        batch_size: rows per forward pass.

    Returns:
        (preds, cell_drug_embs): CPU tensors concatenated over all batches.
    """
    try:
        from tqdm import tqdm  # progress bar is optional
    except ImportError:
        def tqdm(iterable, **kwargs):
            return iterable
    total = genes.shape[0]
    preds_list = []
    cell_drug_list = []
    for start in tqdm(range(0, total, batch_size), desc="Batched prediction", unit="batch"):
        end = min(start + batch_size, total)
        batch_genes = genes[start:end]
        batch_drug = drug_embedding[start:end]
        batch_cov = cov_tensor[start:end]
        if start == 0:
            print("Inside batched_predict_flat, first batch shapes:")
            print(" batch_genes:", batch_genes.shape)
            print(" batch_drug :", batch_drug.shape)
            print(" batch_cov :", batch_cov.shape)
        with torch.no_grad():
            preds, cell_drug_emb = model.model.compute_prediction(batch_genes, batch_drug, [batch_cov])
        preds_list.append(preds.cpu())
        cell_drug_list.append(cell_drug_emb.cpu())
    return torch.cat(preds_list, dim=0), torch.cat(cell_drug_list, dim=0)
+ """ + model = ChemCPA.load_from_checkpoint(str(model_path)) + model.eval() + if device is None: + device = next(model.model.parameters()).device + gene_control_data = gene_control_data.to(device) + drug_embeddings = drug_embeddings.to(device) + + # Prepare drug combinations (singles and, if use_pairs True, pairs) + n_drugs, emb_dim = drug_embeddings.shape + drug_combos = [[i] for i in range(n_drugs)] + if use_pairs: + for i in range(n_drugs): + for j in range(i+1, n_drugs): + drug_combos.append([i, j]) + n_combos = len(drug_combos) + max_combo = max(len(combo) for combo in drug_combos) + combo_emb_input = torch.zeros((n_combos, max_combo, emb_dim), device=device) + dosages = torch.zeros((n_combos, max_combo), device=device) + drugs_idx = torch.zeros((n_combos, max_combo), dtype=torch.long, device=device) + for idx, combo in enumerate(drug_combos): + for pos, d_idx in enumerate(combo): + combo_emb_input[idx, pos] = drug_embeddings[d_idx] + dosages[idx, pos] = 1.0 + drugs_idx[idx, pos] = d_idx + + print("Drug embeddings shape:", drug_embeddings.shape) + print("Drug embeddings with combos shape:", combo_emb_input.sum(dim=1).shape) + + with torch.no_grad(): + drug_combo_transformed = model.model._transform_drug_embeddings(combo_emb_input, dosages, drugs_idx) + latent_dim = drug_combo_transformed.shape[-1] + print("Transformed drug combo embeddings shape:", drug_combo_transformed.shape) + + # Process covariate_ids into one-hot vectors. 
+ num_covariates = model.model.num_covariates + cov_config = [] + for cid, n in zip(covariate_ids, num_covariates): + one_hot = torch.zeros((1, n), device=device) + one_hot[0, cid] = 1.0 + cov_config.append(one_hot) + with torch.no_grad(): + cov_emb_list = model.model.compute_covariate_embeddings_(cov_config) + combined_cov = sum(cov_emb_list) + if combined_cov.dim() != 2 or combined_cov.size(0) != 1: + combined_cov = combined_cov.unsqueeze(0) + print("Combined covariate shape:", combined_cov.shape) # expected [1, latent_dim] + + # Expand inputs: + n_cells, gene_dim = gene_control_data.shape + genes_expanded = gene_control_data.unsqueeze(1).expand(n_cells, n_combos, gene_dim) + drug_expanded = drug_combo_transformed.unsqueeze(0).expand(n_cells, n_combos, latent_dim) + cov_expanded = combined_cov.unsqueeze(0).unsqueeze(0).expand(n_cells, n_combos, latent_dim) + + total = n_cells * n_combos + print("n_cells =", n_cells, "n_combos =", n_combos) + print("Total combined predictions to compute:", total) + genes_flat = genes_expanded.reshape(total, gene_dim) + drug_flat = drug_expanded.reshape(total, latent_dim) + cov_flat = cov_expanded.reshape(total, latent_dim) + print("genes_flat shape:", genes_flat.shape) + print("drug_flat shape:", drug_flat.shape) + print("cov_flat shape:", cov_flat.shape) + + preds_all_flat, cell_drug_emb_all_flat = batched_predict_flat(model, genes_flat, drug_flat, cov_flat, batch_size) + out_dim = preds_all_flat.shape[-1] + preds = preds_all_flat.reshape(n_cells, n_combos, out_dim) + cell_drug_emb = cell_drug_emb_all_flat.reshape(n_cells, n_combos, out_dim) + return preds, cell_drug_emb, drug_combos + +def prepare(): + """ + Prepares the necessary inputs for prediction. + Returns: + gene_control_data: Tensor with gene control data. + drug_embeddings: Tensor with drug embeddings. + covariate_ids: List of ints (one per covariate dimension). + model_path: Path to the model checkpoint. + device: torch.device. 
+ """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + gene_control_data = load_control_genes(DATASET_FILE, device) + drug_embeddings = load_random_drug_embeddings(device) + # For this example, assume the model has two covariate dimensions; use 0 for both. + covariate_ids = [0, 0] + model_path = Path(__file__).resolve().parent.parent / CKPT_RELATIVE_PATH + return gene_control_data, drug_embeddings, covariate_ids, model_path, device + +def main(): + gene_control_data, drug_embeddings, covariate_ids, model_path, device = prepare() + preds, cell_drug_emb, drug_combos = predict(model_path, drug_embeddings, covariate_ids, gene_control_data, device=device) + print("Final predictions shape:", preds.shape) + print("Final cell-drug embedding shape:", cell_drug_emb.shape) + print("Drug combos (singles and pairs):") + for combo in drug_combos: + print(combo) + +if __name__ == "__main__": + main()