@@ -98,7 +98,7 @@ def _encoder(self):
9898 log_var = Dense (self .z_dim , kernel_initializer = self .init_w )(h )
9999 z = Lambda (self ._sample_z , output_shape = (self .z_dim ,), name = "Z" )([mean , log_var ])
100100
101- self .encoder_model = Model (inputs = self .x , outputs = [ mean , log_var , z ] , name = "encoder" )
101+ self .encoder_model = Model (inputs = self .x , outputs = z , name = "encoder" )
102102 return mean , log_var
103103
104104 def _decoder (self ):
@@ -178,7 +178,7 @@ def _create_network(self):
178178 self .mu , self .log_var = self ._encoder ()
179179
180180 self .x_hat = self ._decoder ()
181- self .vae_model = Model (inputs = self .x , outputs = self .decoder_model (self .encoder_model (self .x )[ 2 ] ), name = "VAE" )
181+ self .vae_model = Model (inputs = self .x , outputs = self .decoder_model (self .encoder_model (self .x )), name = "VAE" )
182182
183183 def _loss_function (self ):
184184 """
@@ -224,7 +224,7 @@ def to_latent(self, data):
224224 latent: numpy nd-array
225225 Returns array containing latent space encoding of 'data'
226226 """
227- latent = self .encoder_model .predict (data )[ 2 ]
227+ latent = self .encoder_model .predict (data )
228228 return latent
229229
230230 def _avg_vector (self , data ):
@@ -246,7 +246,7 @@ def _avg_vector(self, data):
246246 latent_avg = numpy .average (latent , axis = 0 )
247247 return latent_avg
248248
249- def reconstruct (self , data , use_data = False ):
249+ def reconstruct (self , data ):
250250 """
251251 Map back the latent space encoding via the decoder.
252252
@@ -265,11 +265,6 @@ def reconstruct(self, data, use_data=False):
265265 rec_data: 'numpy nd-array'
266266 Returns 'numpy nd-array` containing reconstructed 'data' in shape [n_obs, n_vars].
267267 """
268- # if use_data:
269- # latent = data
270- # else:
271- # latent = self.to_latent(data)
272- # rec_data = self.sess.run(self.x_hat, feed_dict={self.z_mean: latent, self.is_training: False})
273268 rec_data = self .decoder_model .predict (x = data )
274269 return rec_data
275270
@@ -321,7 +316,7 @@ def linear_interpolation(self, source_adata, dest_adata, n_steps):
321316 vector = start * (1 - alpha ) + end * alpha
322317 vectors [i , :] = vector
323318 vectors = numpy .array (vectors )
324- interpolation = self .reconstruct (vectors , use_data = True )
319+ interpolation = self .reconstruct (vectors )
325320 return interpolation
326321
327322 def predict (self , adata , conditions , cell_type_key , condition_key , adata_to_predict = None , celltype_to_predict = None , obs_key = "all" ):
@@ -393,7 +388,7 @@ def predict(self, adata, conditions, cell_type_key, condition_key, adata_to_pred
393388 else :
394389 latent_cd = self .to_latent (ctrl_pred .X )
395390 stim_pred = delta + latent_cd
396- predicted_cells = self .reconstruct (stim_pred , use_data = True )
391+ predicted_cells = self .reconstruct (stim_pred )
397392 return predicted_cells , delta
398393
399394 def restore_model (self ):
0 commit comments