Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 94 additions & 127 deletions chemCPA/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,76 +384,73 @@ def set_hparams_(self, seed, hparams):
self.hparams.update(hparams)
return self.hparams

def _lookup_drug_embeddings(self, drugs_idx):
if drugs_idx.ndim == 1:
drugs_idx = drugs_idx.unsqueeze(1)
return self.drug_embeddings.weight[drugs_idx]

def _transform_drug_embeddings(self, gathered_embeddings, dosages, drugs_idx):
batch_size, combo_drugs, _ = gathered_embeddings.shape

if self.doser_type == "mlp":
scaled_dosages_list = []
for i in range(combo_drugs):
dose_i = dosages[:, i].unsqueeze(1)
scaled_i = []
for b in range(batch_size):
d_idx = drugs_idx[b, i].item()
scaled_val = self.dosers[d_idx](dose_i[b]).sigmoid()
scaled_i.append(scaled_val)
scaled_i = torch.cat(scaled_i, dim=0).unsqueeze(1)
scaled_dosages_list.append(scaled_i)
scaled_dosages = torch.cat(scaled_dosages_list, dim=1)
elif self.doser_type == "amortized":
scaled_list = []
for i in range(combo_drugs):
emb_i = gathered_embeddings[:, i, :]
dose_i = dosages[:, i].unsqueeze(-1)
cat_i = torch.cat([emb_i, dose_i], dim=1)
scaled_i = self.dosers(cat_i).sigmoid()
scaled_list.append(scaled_i)
scaled_dosages = torch.stack(scaled_list, dim=1).squeeze(-1)
elif self.doser_type in ("sigm", "logsigm"):
scaled_list = []
for i in range(combo_drugs):
dose_i = dosages[:, i]
scaled_val = self.dosers(dose_i, drugs_idx[:, i]).unsqueeze(1)
scaled_list.append(scaled_val)
scaled_dosages = torch.cat(scaled_list, dim=1)
else:
scaled_dosages = dosages

if not self.enable_cpa_mode and self.drug_embedding_encoder is not None:
transformed_list = []
for i in range(combo_drugs):
emb_i = self.drug_embedding_encoder(gathered_embeddings[:, i, :])
transformed_list.append(emb_i.unsqueeze(1))
transformed = torch.cat(transformed_list, dim=1)
else:
transformed = gathered_embeddings

scaled_dosages_expanded = scaled_dosages.unsqueeze(-1)
scaled_embeddings = transformed * scaled_dosages_expanded
combo_embedding = scaled_embeddings.sum(dim=1)
return combo_embedding

def compute_drug_embeddings_(self, drugs=None, drugs_idx=None, dosages=None):
"""
Compute sum of drug embeddings, each scaled by a dose-response curve.
Computes the drug embedding.
For a drugs_idx input, it performs a lookup then applies doser scaling and transformation.
"""
assert (drugs is not None) or (drugs_idx is not None and dosages is not None), (
"Either `drugs` or (`drugs_idx` and `dosages`) must be provided."
)
if self.num_drugs <= 0:
return None

drugs, drugs_idx, dosages = _move_inputs(drugs, drugs_idx, dosages, device=self.device)
all_embeddings = self.drug_embeddings.weight # shape [num_drugs, embedding_dim]

if drugs_idx is not None:
# Ensure 2D
if drugs_idx.ndim == 1:
drugs_idx = drugs_idx.unsqueeze(1)
dosages = dosages.unsqueeze(1)
batch_size, combo_drugs = drugs_idx.shape

gathered_embeddings = all_embeddings[drugs_idx] # [batch_size, combo_drugs, emb_dim]

# scaled dosages
if self.doser_type == "mlp":
scaled_dosages_list = []
for i in range(combo_drugs):
idx_i = drugs_idx[:, i]
dose_i = dosages[:, i].unsqueeze(1)
scaled_i = []
for b in range(batch_size):
d_idx = idx_i[b].item()
scaled_i.append(self.dosers[d_idx](dose_i[b]).sigmoid())
scaled_i = torch.cat(scaled_i, dim=0)
scaled_dosages_list.append(scaled_i.unsqueeze(1))
scaled_dosages = torch.cat(scaled_dosages_list, dim=1)

elif self.doser_type == "amortized":
scaled_list = []
for i in range(combo_drugs):
emb_i = gathered_embeddings[:, i, :]
dose_i = dosages[:, i].unsqueeze(-1)
cat_i = torch.cat([emb_i, dose_i], dim=1)
scaled_i = self.dosers(cat_i).sigmoid()
scaled_list.append(scaled_i)
scaled_dosages = torch.stack(scaled_list, dim=1).squeeze(-1)

elif self.doser_type in ("sigm", "logsigm"):
scaled_list = []
for i in range(combo_drugs):
dose_i = dosages[:, i]
drug_i = drugs_idx[:, i]
scaled_list.append(self.dosers(dose_i, drug_i).unsqueeze(1))
scaled_dosages = torch.cat(scaled_list, dim=1)
else:
scaled_dosages = dosages

# transform each embedding if needed
if not self.enable_cpa_mode and self.drug_embedding_encoder is not None:
transformed_list = []
for i in range(combo_drugs):
emb_i = self.drug_embedding_encoder(gathered_embeddings[:, i, :])
transformed_list.append(emb_i.unsqueeze(1))
transformed = torch.cat(transformed_list, dim=1)
else:
transformed = gathered_embeddings

scaled_dosages_expanded = scaled_dosages.unsqueeze(-1)
scaled_embeddings = transformed * scaled_dosages_expanded
combo_embedding = scaled_embeddings.sum(dim=1)
return combo_embedding
gathered = self._lookup_drug_embeddings(drugs_idx)
return self._transform_drug_embeddings(gathered, dosages, drugs_idx)
else:
# (drugs) => shape [batch_size, num_drugs]
all_embeddings = self.drug_embeddings.weight
if self.doser_type == "mlp":
scaled_list = []
for d in range(self.num_drugs):
Expand All @@ -472,55 +469,48 @@ def compute_drug_embeddings_(self, drugs=None, drugs_idx=None, dosages=None):
transformed_embeddings = self.drug_embedding_encoder(all_embeddings)
else:
transformed_embeddings = all_embeddings

drug_combo_emb = scaled_dosages @ transformed_embeddings
return drug_combo_emb

def predict(
self,
genes,
drugs=None,
drugs_idx=None,
dosages=None,
covariates=None,
return_latent_basal=False,
):
"""
Predict how gene expression in `genes` changes when treated with `drugs`.
"""
assert (drugs is not None) or (drugs_idx is not None and dosages is not None)
return scaled_dosages @ transformed_embeddings

def compute_covariate_embeddings_(self, covariates):
if covariates is None or self.num_covariates[0] == 0:
return []
cov_embeddings = []
for i, emb_cov in enumerate(self.covariates_embeddings):
emb_cov = emb_cov.to(self.device)
cov_idx = covariates[i].argmax(dim=1)
cov_embeddings.append(emb_cov(cov_idx))
return cov_embeddings

def predict(self, genes, drugs=None, drugs_idx=None, dosages=None, covariates=None, return_latent_basal=False):
genes, drugs, drugs_idx, dosages, covariates = _move_inputs(
genes, drugs, drugs_idx, dosages, covariates, device=self.device
)
latent_basal = self.encoder(genes)
latent_treated = latent_basal

if self.num_drugs > 0:
drug_embedding = self.compute_drug_embeddings_(drugs=drugs, drugs_idx=drugs_idx, dosages=dosages)
drug_embedding = self.compute_drug_embeddings_(drugs=drugs, drugs_idx=drugs_idx, dosages=dosages)
cov_embeddings = self.compute_covariate_embeddings_(covariates)
return self.compute_prediction(genes, drug_embedding, cov_embeddings, return_latent_basal)

def compute_prediction(self, genes, drug_embedding=None, cov_embeddings=None, return_latent_basal=False):
latent = self.encoder(genes)
latent_with_cov = latent
if cov_embeddings:
for cov_emb in cov_embeddings:
latent_with_cov = latent_with_cov + cov_emb
latent_treated = latent_with_cov
if drug_embedding is not None:
latent_treated = latent_treated + drug_embedding

if self.num_covariates[0] > 0:
for cov_type, emb_cov in enumerate(self.covariates_embeddings):
emb_cov = emb_cov.to(self.device)
cov_idx = covariates[cov_type].argmax(1)
latent_treated = latent_treated + emb_cov(cov_idx)

# Construct cell_drug_embedding for e.g. multi-task or logging
if self.num_covariates[0] > 0 and self.num_drugs > 0:
cell_drug_embedding = torch.cat([emb_cov(cov_idx), drug_embedding], dim=1)
elif self.num_drugs > 0:
cell_drug_embedding = drug_embedding
else:
cell_drug_embedding = torch.zeros_like(latent_basal)

gene_reconstructions = self.decoder(latent_treated)
dim = gene_reconstructions.size(1) // 2
mean = gene_reconstructions[:, :dim]
var = F.softplus(gene_reconstructions[:, dim:])
normalized_reconstructions = torch.cat([mean, var], dim=1)

if cov_embeddings and drug_embedding is not None:
cell_drug_embedding = torch.cat([cov_embeddings[-1], drug_embedding], dim=1)
elif drug_embedding is not None:
cell_drug_embedding = drug_embedding
else:
cell_drug_embedding = torch.zeros_like(latent)
if return_latent_basal:
return normalized_reconstructions, cell_drug_embedding, (latent_basal, drug_embedding, latent_treated)
return normalized_reconstructions, cell_drug_embedding, (latent, drug_embedding, latent_treated)
return normalized_reconstructions, cell_drug_embedding

def early_stopping(self, score):
Expand All @@ -533,21 +523,8 @@ def early_stopping(self, score):
self.patience_trials += 1
return self.patience_trials > self.patience

def update(
self,
genes,
drugs=None,
drugs_idx=None,
dosages=None,
degs=None,
covariates=None,
):
"""
(Optional) if you manually call a training step here; typically unused under Lightning.
"""
def update(self, genes, drugs=None, drugs_idx=None, dosages=None, degs=None, covariates=None):
assert (drugs is not None) or (drugs_idx is not None and dosages is not None)

# ---- Forward pass (with debugging) ----
gene_reconstructions, cell_drug_embedding, (latent_basal, drug_embedding, latent_treated) = self.predict(
genes=genes,
drugs=drugs,
Expand All @@ -559,8 +536,6 @@ def update(
dim = gene_reconstructions.size(1) // 2
mean = gene_reconstructions[:, :dim]
var = gene_reconstructions[:, dim:]

# Debug check for NaNs
if torch.isnan(mean).any() or torch.isnan(var).any():
print(
f"NaN detected in mean/var:\n"
Expand All @@ -570,17 +545,10 @@ def update(
f" mean[:5] = {mean[:5]}\n"
f" var[:5] = {var[:5]}"
)

# ---- Reconstruction loss ----
reconstruction_loss = self.loss_autoencoder(input=mean, target=genes, var=var)

# ---- Drug adversary loss (if used) ----
adversary_drugs_loss = torch.tensor([0.0], device=self.device)
if self.num_drugs > 0:
adversary_drugs_predictions = self.adversary_drugs(latent_basal)
# ...BCEWithLogitsLoss if multi-label...

# ---- Covariates adversary loss (if used) ----
adversary_covariates_loss = torch.tensor([0.0], device=self.device)
if self.num_covariates[0] > 0:
adversary_covariate_predictions = []
Expand All @@ -592,9 +560,7 @@ def update(
pred,
covariates[i].argmax(1),
)

self.iteration += 1

return {
"loss_reconstruction": reconstruction_loss.item(),
"loss_adv_drugs": adversary_drugs_loss.item(),
Expand All @@ -605,3 +571,4 @@ def update(
    def defaults(cls):
        # Default hyperparameters: the result of set_hparams_ with seed 0 and no
        # overrides (empty string in place of an hparams dict).
        # NOTE(review): this passes `cls` itself as `self`, so set_hparams_ will
        # read/write attributes on the class rather than an instance — confirm
        # this is intentional (a decorator above this line is outside this view).
        return cls.set_hparams_(cls, 0, "")


Loading