Skip to content

Commit 1d73aff

Browse files
committed
only use top 1000 genes for VAE
1 parent 3f9debf commit 1d73aff

File tree

2 files changed

+128
-166
lines changed

2 files changed

+128
-166
lines changed

dsbook/unsupervised/VAEofCarcinomas.ipynb

Lines changed: 88 additions & 115 deletions
Large diffs are not rendered by default.

dsbook/unsupervised/VAEofCarcinomas.md

Lines changed: 40 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,29 @@ luad = tcga.get_expression_data(my_path + "../data/luad_tcga_pan_can_atlas_2018.
4040
lusc = tcga.get_expression_data(my_path + "../data/lusc_tcga_pan_can_atlas_2018.tar.gz", 'https://cbioportal-datahub.s3.amazonaws.com/lusc_tcga_pan_can_atlas_2018.tar.gz',"data_mrna_seq_v2_rsem.txt")
4141
```
4242

43-
We now merge the datasets, and ensure that we only include transcripts that are measured in all samples with counts greater than zero. Further, we standardize the measurements so that every gene's expression values are scaled using scikit-learn's StandardScaler.
43+
We now merge the datasets, and ensure that we only include transcripts that are measured in all samples with counts greater than zero. Subsequently, we log-transform the data and reduce the set to the 1,000 transcripts with the highest variance. Finally, we standardize the measurements so that every gene's expression values are scaled using scikit-learn's StandardScaler.
4444

4545
```{code-cell} ipython3
46-
:id: nTGtXhUgdZIw
47-
4846
from sklearn.preprocessing import StandardScaler
4947
scaler = StandardScaler()
5048
combined = pd.concat([lusc[lusc.index.notna()] , luad[luad.index.notna()]], axis=1, sort=False)
5149
# Drop rows with any missing values
5250
combined.dropna(axis=0, how='any', inplace=True)
5351
combined = combined.loc[~(combined<=0.0).any(axis=1)]
5452
combined.index = combined.index.astype(str)
55-
X=scaler.fit_transform(np.log2(combined).T).T
56-
combined = pd.DataFrame(data=X,index=combined.index,columns=combined.columns)
53+
log_combined = np.log2(combined)
54+
var = log_combined.var(axis=1)
55+
top_k = 1000
56+
top_genes = var.nlargest(top_k).index
57+
58+
log_combined = log_combined.loc[top_genes]
59+
scaler = StandardScaler()
60+
X = scaler.fit_transform(log_combined.T).astype(np.float32)
61+
combined_reduced = pd.DataFrame(
62+
data=X.T,
63+
index=log_combined.index,
64+
columns=log_combined.columns,
65+
)
5766
```
5867

5968

@@ -68,15 +77,17 @@ import numpy as np
6877
6978
# Setting training parameters
7079
batch_size, lr, epochs, log_interval = 256, 1e-3, 501, 100
71-
hidden_dim, latent_dim = 2048, 12
80+
hidden_dim, latent_dim = 512, 12
7281
7382
# Check if GPU is available
7483
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7584
torch.manual_seed(4711)
7685
kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
7786
87+
7888
# Convert combined DataFrame to a PyTorch tensor
79-
datapoints = torch.tensor(combined.to_numpy().T, dtype=torch.float32)
89+
datapoints = torch.tensor(combined_reduced.to_numpy().T, dtype=torch.float32)
90+
input_dim = datapoints.shape[1]
8091
labels = torch.tensor([1.0 for _ in lusc.columns] + [0.0 for _ in luad.columns], dtype=torch.float32)
8192
8293
# Use TensorDataset to create a dataset
@@ -112,7 +123,7 @@ class VAE(nn.Module):
112123
113124
def decode(self, z):
114125
h3 = torch.relu(self.fc3(z))
115-
out = torch.sigmoid(self.fc4(h3))
126+
out = self.fc4(h3)
116127
return out
117128
118129
def forward(self, x):
@@ -124,51 +135,45 @@ class VAE(nn.Module):
124135
Next, we select a gradient-based optimizer (Adam) and the loss function to optimize (reconstruction + KLD). The train and test procedures are defined below.
125136

126137
```{code-cell} ipython3
127-
input_dim = combined.shape[0]
128138
model = VAE(input_dim, hidden_dim, latent_dim).to(device)
129139
optimizer = optim.Adam(model.parameters(), lr=lr)
130140
131141
132142
# Reconstruction + KL divergence losses summed over all elements and batch
133-
def loss_function(recon_x, x, mu, logvar):
134-
MSE = nn.functional.mse_loss(recon_x, x, reduction='sum')
135-
# see Appendix B from VAE paper:
136-
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
137-
# https://arxiv.org/abs/1312.6114
138-
# Calculating the Kullback–Leibler divergence
139-
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
140-
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
141-
#print(f"MSE={MSE}, KLD={KLD}")
142-
return MSE + KLD
143+
def loss_function(recon_x, x, mu, logvar, beta=1.0):
144+
# reconstruction per feature
145+
recon_loss = F.mse_loss(recon_x, x, reduction='mean')
146+
147+
# KL per sample, then mean
148+
kl_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
149+
kld = kl_per_sample.mean()
150+
151+
return recon_loss + beta * kld, recon_loss.item(), kld.item()
143152
144153
def train(epoch):
145154
model.train()
146155
train_loss = 0
147156
for batch_idx, (data, _) in enumerate(train_loader):
148-
data = data.to(device)
157+
data = data.to(device, non_blocking=True)
149158
optimizer.zero_grad()
150159
recon_batch, mu, logvar = model(data)
151-
loss = loss_function(recon_batch, data, mu, logvar)
160+
loss, recon_loss, kld = loss_function(recon_batch, data, mu, logvar, beta=1.0)
152161
loss.backward()
153-
train_loss += loss.item()
154162
optimizer.step()
155-
if epoch % log_interval == 0:
156-
print('====> Epoch: {} Average loss: {:.4f}'.format(
157-
epoch, train_loss / len(train_loader.dataset)))
163+
train_loss += loss.item()
164+
if epoch % log_interval == 0:
165+
print('====> Epoch: {} Average loss: {:.4f}'.format(
166+
epoch, train_loss / len(train_loader.dataset)))
158167
159168
160169
def test(epoch):
161170
model.eval()
162-
test_loss = 0
163171
with torch.no_grad():
164-
for i, (data, _) in enumerate(test_loader):
165-
# for i, (data, _) in enumerate(train_loader):
166-
data = data.to(device)
167-
recon_batch, mu, logvar = model(data)
168-
test_loss += loss_function(recon_batch, data, mu, logvar).item()
169-
if epoch % log_interval == 0:
170-
test_loss /= len(test_loader.dataset)
171-
print('====> Test set loss: {:.4f}'.format(test_loss))
172+
X_tensor = datapoints.to(device)
173+
x_hat_, mu_, logvar_ = model(X_tensor)
174+
x_hat = x_hat_.cpu().numpy()
175+
z = mu_.cpu().numpy()
176+
std = torch.exp(0.5 * logvar_).cpu().numpy()
172177
```
173178

174179
Now we are set to run the procedure for 500 epochs.
@@ -288,18 +293,14 @@ Further, we can use the network to generate "typical" expression profiles. We ha
288293
```{code-cell} ipython3
289294
290295
z_fix = torch.tensor(np.concatenate(([means["LUSC"]],[means["LUAD"]]), axis=0))
291-
292296
z_fix = z_fix.to(device)
293297
x_fix = model.decode(z_fix).cpu().detach().numpy()
294-
predicted = pd.DataFrame(data=x_fix.T, index=combined.index, columns=["LUSC", "LUAD"])
298+
predicted = pd.DataFrame(data=x_fix.T, index=combined_reduced.index, columns=["LUSC", "LUAD"])
295299
```
296300

297301
Using these generated profiles we may, for instance, identify the genes most differentially expressed between the generated LUSC and LUAD profiles.
298302

299303
```{code-cell} ipython3
300-
:id: rdGnIpvgdZI1
301-
:outputId: 3fcf9449-bfea-443c-d2fc-179054e1c906
302-
303304
predicted["diff"] = predicted["LUSC"] - predicted["LUAD"]
304305
# predicted.sort_values(by='diff', ascending=False, inplace = True)
305306
```
@@ -312,20 +313,8 @@ The genes that the decoder finds most different between the set means can now be
312313
predicted["diff"].idxmin(axis=0)
313314
```
314315

315-
Which is a [cancer-related](https://www.proteinatlas.org/ENSG00000172731-LRRC20/cancer) protein.
316-
317-
+++
318-
319316
and then in the negative direction (larger in LUAD than LUSC).
320317

321318
```{code-cell} ipython3
322319
predicted["diff"].idxmax(axis=0)
323320
```
324-
325-
Which is a [prognostic marker](https://www.proteinatlas.org/ENSG00000146054-TRIM7/cancer) for survival in LUAD.
326-
327-
Here these two genes seem to be the largest differentiators between LUSC and LUAD. We can also note that, as with PCA, the gene KRT17 appears quite different between the cancer types:
328-
329-
```{code-cell} ipython3
330-
predicted.loc["KRT17"]
331-
```

0 commit comments

Comments
 (0)