11import warnings
22
33import numpy as np
4+ import pandas as pd
45import torch
56from packaging import version
67from torch import optim
1314
1415class Discriminator (Module ):
1516
16- def __init__ (self , input_dim , dis_dims , pack = 10 ):
17+ def __init__ (self , input_dim , discriminator_dim , pack = 10 ):
1718 super (Discriminator , self ).__init__ ()
1819 dim = input_dim * pack
1920 self .pack = pack
2021 self .packdim = dim
2122 seq = []
22- for item in list (dis_dims ):
23+ for item in list (discriminator_dim ):
2324 seq += [Linear (dim , item ), LeakyReLU (0.2 ), Dropout (0.5 )]
2425 dim = item
2526
@@ -222,6 +223,31 @@ def _cond_loss(self, data, c, m):
222223
223224 return (loss * m ).sum () / data .size ()[0 ]
224225
def _validate_discrete_columns(self, train_data, discrete_columns):
    """Check whether ``discrete_columns`` exists in ``train_data``.

    Args:
        train_data (numpy.ndarray or pandas.DataFrame):
            Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
        discrete_columns (list-like):
            List of discrete columns to be used to generate the Conditional
            Vector. If ``train_data`` is a Numpy array, this list should
            contain the integer indices of the columns. Otherwise, if it is
            a ``pandas.DataFrame``, this list should contain the column names.

    Raises:
        TypeError:
            If ``train_data`` is neither a ``pandas.DataFrame`` nor a
            ``numpy.ndarray``.
        ValueError:
            If any entry of ``discrete_columns`` is not a column of
            ``train_data``: an unknown name for a DataFrame, or an index
            that is negative, past the last column, or not comparable to
            an integer (e.g. a column name) for a Numpy array.
    """
    if isinstance(train_data, pd.DataFrame):
        invalid_columns = set(discrete_columns) - set(train_data.columns)
    elif isinstance(train_data, np.ndarray):
        invalid_columns = []
        for column in discrete_columns:
            try:
                out_of_range = column < 0 or column >= train_data.shape[1]
            except TypeError:
                # Entries that cannot be ordered against an int (typically
                # column names passed together with an array) are reported
                # as invalid columns instead of leaking a comparison error.
                out_of_range = True

            if out_of_range:
                invalid_columns.append(column)
    else:
        raise TypeError('``train_data`` should be either pd.DataFrame or np.array.')

    if invalid_columns:
        raise ValueError('Invalid columns found: {}'.format(invalid_columns))
250+
225251 def fit (self , train_data , discrete_columns = tuple (), epochs = None ):
226252 """Fit the CTGAN Synthesizer models to the training data.
227253
@@ -234,6 +260,8 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
234260 contain the integer indices of the columns. Otherwise, if it is
235261 a ``pandas.DataFrame``, this list should contain the column names.
236262 """
263+ self ._validate_discrete_columns (train_data , discrete_columns )
264+
237265 if epochs is None :
238266 epochs = self ._epochs
239267 else :
0 commit comments