@@ -7,7 +7,7 @@ class DataSampler(object):
77 """DataSampler samples the conditional vector and corresponding data for CTGAN."""
88
99 def __init__ (self , data , output_info , log_frequency ):
10- self ._data = data
10+ self ._data_length = len ( data )
1111
1212 def is_discrete_column (column_info ):
1313 return (len (column_info ) == 1
@@ -115,33 +115,34 @@ def sample_original_condvec(self, batch):
115115 if self ._n_discrete_columns == 0 :
116116 return None
117117
118+ category_freq = self ._discrete_column_category_prob .flatten ()
119+ category_freq = category_freq [category_freq != 0 ]
120+ category_freq = category_freq / np .sum (category_freq )
121+ col_idxs = np .random .choice (np .arange (len (category_freq )), batch , p = category_freq )
118122 cond = np .zeros ((batch , self ._n_categories ), dtype = 'float32' )
119-
120- for i in range (batch ):
121- row_idx = np .random .randint (0 , len (self ._data ))
122- col_idx = np .random .randint (0 , self ._n_discrete_columns )
123- matrix_st = self ._discrete_column_matrix_st [col_idx ]
124- matrix_ed = matrix_st + self ._discrete_column_n_category [col_idx ]
125- pick = np .argmax (self ._data [row_idx , matrix_st :matrix_ed ])
126- cond [i , pick + self ._discrete_column_cond_st [col_idx ]] = 1
123+ cond [np .arange (batch ), col_idxs ] = 1
127124
128125 return cond
129126
130- def sample_data (self , n , col , opt ):
127+ def sample_data (self , data , n , col , opt ):
131128 """Sample data from original training data satisfying the sampled conditional vector.
132129
130+ Args:
131+ data:
132+ The training data.
133133 Returns:
134- n rows of matrix data.
134+ n:
135+ n rows of matrix data.
135136 """
136137 if col is None :
137- idx = np .random .randint (len (self . _data ), size = n )
138- return self . _data [idx ]
138+ idx = np .random .randint (len (data ), size = n )
139+ return data [idx ]
139140
140141 idx = []
141142 for c , o in zip (col , opt ):
142143 idx .append (np .random .choice (self ._rid_by_cat_cols [c ][o ]))
143144
144- return self . _data [idx ]
145+ return data [idx ]
145146
146147 def dim_cond_vec (self ):
147148 """Return the total number of categories."""
0 commit comments