diff --git a/ctgan/errors.py b/ctgan/errors.py new file mode 100644 index 00000000..3f3bccac --- /dev/null +++ b/ctgan/errors.py @@ -0,0 +1,5 @@ +"""Custom errors for CTGAN.""" + + +class InvalidDataError(Exception): + """Error to raise when data is not valid.""" diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index 5fdbc269..29606a34 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -11,6 +11,7 @@ from ctgan.data_sampler import DataSampler from ctgan.data_transformer import DataTransformer +from ctgan.errors import InvalidDataError from ctgan.synthesizers.base import BaseSynthesizer, random_state @@ -289,6 +290,31 @@ def _validate_discrete_columns(self, train_data, discrete_columns): if invalid_columns: raise ValueError(f'Invalid columns found: {invalid_columns}') + def _validate_null_data(self, train_data, discrete_columns): + """Check whether null values exist in continuous ``train_data``. + + Args: + train_data (numpy.ndarray or pandas.DataFrame): + Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame. + discrete_columns (list-like): + List of discrete columns to be used to generate the Conditional + Vector. If ``train_data`` is a Numpy array, this list should + contain the integer indices of the columns. Otherwise, if it is + a ``pandas.DataFrame``, this list should contain the column names. + """ + if isinstance(train_data, pd.DataFrame): + continuous_cols = list(set(train_data.columns) - set(discrete_columns)) + any_nulls = train_data[continuous_cols].isna().any().any() + else: + continuous_cols = [i for i in range(train_data.shape[1]) if i not in discrete_columns] + any_nulls = pd.DataFrame(train_data)[continuous_cols].isna().any().any() + + if any_nulls: + raise InvalidDataError( + 'CTGAN does not support null values in the continuous training data. ' + 'Please remove all null values from your continuous training data.' + ) + @random_state def fit(self, train_data, discrete_columns=(), epochs=None): """Fit the CTGAN Synthesizer models to the training data. @@ -303,6 +329,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None): a ``pandas.DataFrame``, this list should contain the column names. """ self._validate_discrete_columns(train_data, discrete_columns) + self._validate_null_data(train_data, discrete_columns) if epochs is None: epochs = self._epochs diff --git a/tests/integration/synthesizer/test_ctgan.py b/tests/integration/synthesizer/test_ctgan.py index 5419b094..62881c18 100644 --- a/tests/integration/synthesizer/test_ctgan.py +++ b/tests/integration/synthesizer/test_ctgan.py @@ -15,6 +15,7 @@ import pandas as pd import pytest +from ctgan.errors import InvalidDataError from ctgan.synthesizers.ctgan import CTGAN @@ -132,6 +133,25 @@ def test_categorical_nan(): assert {'b', 'c'}.issubset(values) +def test_continuous_nan(): + """Test the CTGAN with missing numerical values.""" + # Setup + data = pd.DataFrame({ + 'continuous': [np.nan, 1.0, 2.0] * 10, + 'discrete': ['a', 'b', 'c'] * 10, + }) + discrete_columns = ['discrete'] + error_message = ( + 'CTGAN does not support null values in the continuous training data. ' + 'Please remove all null values from your continuous training data.' + ) + + # Run and Assert + ctgan = CTGAN(epochs=1) + with pytest.raises(InvalidDataError, match=error_message): + ctgan.fit(data, discrete_columns) + + def test_synthesizer_sample(): """Test the CTGAN samples the correct datatype.""" data = pd.DataFrame({'discrete': np.random.choice(['a', 'b', 'c'], 100)}) diff --git a/tests/unit/synthesizer/test_ctgan.py b/tests/unit/synthesizer/test_ctgan.py index 03070c2b..8b774358 100644 --- a/tests/unit/synthesizer/test_ctgan.py +++ b/tests/unit/synthesizer/test_ctgan.py @@ -3,11 +3,13 @@ from unittest import TestCase from unittest.mock import Mock +import numpy as np import pandas as pd import pytest import torch from ctgan.data_transformer import SpanInfo +from ctgan.errors import InvalidDataError from ctgan.synthesizers.ctgan import CTGAN, Discriminator, Generator, Residual @@ -289,3 +291,42 @@ def test__validate_discrete_columns(self): ctgan = CTGAN(epochs=1) with pytest.raises(ValueError, match=r'Invalid columns found: {\'doesnt exist\'}'): ctgan.fit(data, discrete_columns) + + def test__validate_null_data(self): + """Test `_validate_null_data` with pandas and numpy data. + + Check the appropriate error is raised if null values are present in + continuous columns, both for numpy arrays and dataframes. + """ + # Setup + discrete_df = pd.DataFrame({'discrete': ['a', 'b']}) + discrete_array = np.array([['a'], ['b']]) + continuous_no_nulls_df = pd.DataFrame({'continuous': [0, 1]}) + continuous_no_nulls_array = np.array([[0], [1]]) + continuous_with_null_df = pd.DataFrame({'continuous': [1, np.nan]}) + continuous_with_null_array = np.array([[1], [np.nan]]) + ctgan = CTGAN(epochs=1) + error_message = ( + 'CTGAN does not support null values in the continuous training data. ' + 'Please remove all null values from your continuous training data.' + ) + + # Test discrete DataFrame fits without error + ctgan.fit(discrete_df, ['discrete']) + + # Test discrete array fits without error + ctgan.fit(discrete_array, [0]) + + # Test continuous DataFrame without nulls fits without error + ctgan.fit(continuous_no_nulls_df) + + # Test continuous array without nulls fits without error + ctgan.fit(continuous_no_nulls_array) + + # Test nulls in continuous columns DataFrame errors on fit + with pytest.raises(InvalidDataError, match=error_message): + ctgan.fit(continuous_with_null_df) + + # Test nulls in continuous columns array errors on fit + with pytest.raises(InvalidDataError, match=error_message): + ctgan.fit(continuous_with_null_array)