@@ -32,19 +32,15 @@ class VAEArith:
     """

     def __init__(self, x_dimension, z_dimension=100, **kwargs):
-        tf.reset_default_graph()
         self.x_dim = x_dimension
         self.z_dim = z_dimension
         self.learning_rate = kwargs.get("learning_rate", 0.001)
         self.dropout_rate = kwargs.get("dropout_rate", 0.2)
         self.model_to_use = kwargs.get("model_path", "./models/scgen")
         self.alpha = kwargs.get("alpha", 0.00005)
         self.is_training = tf.placeholder(tf.bool, name='training_flag')
-        self.global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)
         self.x = tf.placeholder(tf.float32, shape=[None, self.x_dim], name="data")
         self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name="latent")
-        self.time_step = tf.placeholder(tf.int32)
-        self.size = tf.placeholder(tf.int32)
         self.init_w = tf.contrib.layers.xavier_initializer()
         self._create_network()
         self._loss_function()
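The constructor no longer maintains an explicit `global_step`, `time_step`, or `size`: none of them feed a learning-rate schedule or the loss, and Adam tracks its own step internally for bias correction. A minimal TF1 sketch of that point, assuming a plain `tf.train.AdamOptimizer` solver (the toy reconstruction loss below is a stand-in, not this file's `_loss_function`):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 10], name="data")
recon = tf.layers.dense(x, 10)  # toy stand-in for the decoder output
loss = tf.reduce_mean(tf.squared_difference(x, recon))
# Adam keeps its own internal step counters for bias correction, so a
# hand-maintained global_step variable is only needed when a schedule,
# checkpoint name, or summary actually reads it.
solver = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
```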
@@ -119,7 +115,8 @@ def _sample_z(self):
         # Returns
-            The computed Tensor of samples with shape [size, z_dim].
+            The computed Tensor of samples with shape [batch_size, z_dim].
         """
-        eps = tf.random_normal(shape=[self.size, self.z_dim])
+        batch_size = tf.shape(self.mu)[0]
+        eps = tf.random_normal(shape=[batch_size, self.z_dim])
         return self.mu + tf.exp(self.log_var / 2) * eps

     def _create_network(self):
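Reading the batch dimension from `mu` with `tf.shape` makes the noise tensor track whatever minibatch is actually fed, with no separate `size` placeholder to keep in sync. A standalone sketch of this dynamic-shape reparameterization; the `mu`/`log_var` placeholders are toy stand-ins for the encoder outputs:

```python
import numpy as np
import tensorflow as tf

mu = tf.placeholder(tf.float32, shape=[None, 100])       # stand-in for self.mu
log_var = tf.placeholder(tf.float32, shape=[None, 100])  # stand-in for self.log_var
batch_size = tf.shape(mu)[0]                     # resolved dynamically per run
eps = tf.random_normal(shape=[batch_size, 100])  # one Gaussian draw per example
z = mu + tf.exp(log_var / 2) * eps               # reparameterization trick

with tf.Session() as sess:
    out = sess.run(z, feed_dict={mu: np.zeros((7, 100), np.float32),
                                 log_var: np.zeros((7, 100), np.float32)})
    print(out.shape)  # (7, 100) -- any batch size works unchanged
```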
@@ -174,7 +171,7 @@ def to_latent(self, data):
             latent: numpy nd-array
                 Returns array containing latent space encoding of 'data'
         """
-        latent = self.sess.run(self.z_mean, feed_dict={self.x: data, self.size: data.shape[0], self.is_training: False})
+        latent = self.sess.run(self.z_mean, feed_dict={self.x: data, self.is_training: False})
         return latent

     def _avg_vector(self, data):
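With `size` gone from the graph, callers of `to_latent` feed nothing but the expression matrix; the batch dimension is inferred. A hypothetical usage sketch (the 6998-gene dimension is made up, and it assumes `__init__` also creates and initializes `self.sess`, as the method's `self.sess.run` call implies):

```python
import numpy as np

model = VAEArith(x_dimension=6998, z_dimension=100)
batch = np.random.rand(32, 6998).astype(np.float32)  # fake expression matrix
latent = model.to_latent(batch)  # no manual size bookkeeping by the caller
print(latent.shape)              # expected: (32, 100)
```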
@@ -429,8 +426,6 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
         """
         if initial_run:
             log.info("----Training----")
-            assign_step_zero = tf.assign(self.global_step, 0)
-            _init_step = self.sess.run(assign_step_zero)
         if not initial_run:
             self.saver.restore(self.sess, self.model_to_use)
         if use_validation and valid_data is None:
@@ -442,9 +437,6 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
         min_delta = threshold
         patience_cnt = 0
         for it in range(n_epochs):
-            increment_global_step_op = tf.assign(self.global_step, self.global_step + 1)
-            _step = self.sess.run(increment_global_step_op)
-            current_step = self.sess.run(self.global_step)
             train_loss = 0.0
             for lower in range(0, train_data.shape[0], batch_size):
                 upper = min(lower + batch_size, train_data.shape[0])
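For context, the surrounding loop windows the training data with a lower/upper index pair, so the final batch may be smaller than `batch_size`; the `upper - lower > 1` guard in the next hunk skips degenerate single-cell batches. A NumPy sketch of the same batching pattern, with a random matrix standing in for the AnnData `.X` field:

```python
import numpy as np

train_X = np.random.rand(100, 10)  # stand-in for train_data.X
batch_size = 32
for lower in range(0, train_X.shape[0], batch_size):
    upper = min(lower + batch_size, train_X.shape[0])  # last batch: 4 rows
    x_mb = train_X[lower:upper, :]
    if upper - lower > 1:  # skip batches holding a single cell
        pass  # run the training step on x_mb here
```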
@@ -454,8 +446,7 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
                 x_mb = train_data[lower:upper, :].X
                 if upper - lower > 1:
                     _, current_loss_train = self.sess.run([self.solver, self.vae_loss],
-                                                          feed_dict={self.x: x_mb, self.time_step: current_step,
-                                                                     self.size: len(x_mb), self.is_training: True})
+                                                          feed_dict={self.x: x_mb, self.is_training: True})
                     train_loss += current_loss_train
             if use_validation:
                 valid_loss = 0
@@ -466,8 +457,7 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
                 else:
                     x_mb = valid_data[lower:upper, :].X
                 current_loss_valid = self.sess.run(self.vae_loss,
-                                                   feed_dict={self.x: x_mb, self.time_step: current_step,
-                                                              self.size: len(x_mb), self.is_training: False})
+                                                   feed_dict={self.x: x_mb, self.is_training: False})
                 valid_loss += current_loss_valid
             loss_hist.append(valid_loss / valid_data.shape[0])
             if it > 0 and loss_hist[it - 1] - loss_hist[it] > min_delta:
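The validation branch averages the epoch's loss over all cells and drives early stopping by comparing successive entries of `loss_hist` against `min_delta`. The continuation below the hunk is not shown, so the patience handling in this plain-Python sketch is an assumption based on the `patience_cnt` counter initialized above:

```python
def should_stop(loss_hist, it, patience_cnt, min_delta, patience=25):
    """Sketch of the early-stopping rule; patience semantics are assumed."""
    if it > 0 and loss_hist[it - 1] - loss_hist[it] > min_delta:
        patience_cnt = 0      # validation loss still improving enough
    else:
        patience_cnt += 1     # stagnating epoch
    return patience_cnt > patience, patience_cnt
```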