Skip to content

Commit ae6ef8e

Browse files
committed
add validation method and tests
1 parent 6ed1f19 commit ae6ef8e

File tree

4 files changed

+106
-0
lines changed

4 files changed

+106
-0
lines changed

ctgan/errors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Custom errors for CTGAN."""
2+
3+
4+
class InvalidDataError(Exception):
    """Exception signalling that the supplied data is not valid."""

ctgan/synthesizers/ctgan.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from ctgan.data_sampler import DataSampler
1313
from ctgan.data_transformer import DataTransformer
14+
from ctgan.errors import InvalidDataError
1415
from ctgan.synthesizers.base import BaseSynthesizer, random_state
1516

1617

@@ -289,6 +290,31 @@ def _validate_discrete_columns(self, train_data, discrete_columns):
289290
if invalid_columns:
290291
raise ValueError(f'Invalid columns found: {invalid_columns}')
291292

293+
def _validate_null_data(self, train_data, discrete_columns):
294+
"""Check whether null values exist in continuous ``train_data``.
295+
296+
Args:
297+
train_data (numpy.ndarray or pandas.DataFrame):
298+
Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
299+
discrete_columns (list-like):
300+
List of discrete columns to be used to generate the Conditional
301+
Vector. If ``train_data`` is a Numpy array, this list should
302+
contain the integer indices of the columns. Otherwise, if it is
303+
a ``pandas.DataFrame``, this list should contain the column names.
304+
"""
305+
if isinstance(train_data, pd.DataFrame):
306+
continuous_cols = list(set(train_data.columns) - set(discrete_columns))
307+
any_nulls = train_data[continuous_cols].isna().any().any()
308+
else:
309+
continuous_cols = [i for i in range(train_data.shape[1]) if i not in discrete_columns]
310+
any_nulls = pd.DataFrame(train_data)[continuous_cols].isna().any().any()
311+
312+
if any_nulls:
313+
raise InvalidDataError(
314+
'CTGAN does not support null values in the continuous training data. '
315+
'Please remove all null values from your continuous training data.'
316+
)
317+
292318
@random_state
293319
def fit(self, train_data, discrete_columns=(), epochs=None):
294320
"""Fit the CTGAN Synthesizer models to the training data.
@@ -303,6 +329,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
303329
a ``pandas.DataFrame``, this list should contain the column names.
304330
"""
305331
self._validate_discrete_columns(train_data, discrete_columns)
332+
self._validate_null_data(train_data, discrete_columns)
306333

307334
if epochs is None:
308335
epochs = self._epochs

tests/integration/synthesizer/test_ctgan.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pandas as pd
1616
import pytest
1717

18+
from ctgan.errors import InvalidDataError
1819
from ctgan.synthesizers.ctgan import CTGAN
1920

2021

@@ -132,6 +133,19 @@ def test_categorical_nan():
132133
assert {'b', 'c'}.issubset(values)
133134

134135

136+
def test_continuous_nan():
    """Check that fitting fails when a continuous column has missing values."""
    # Thirty rows; every third value of the continuous column is NaN.
    repeats = 10
    continuous_values = [np.nan, 1.0, 2.0] * repeats
    discrete_values = ['a', 'b', 'c'] * repeats
    frame = pd.DataFrame({'continuous': continuous_values, 'discrete': discrete_values})

    model = CTGAN(epochs=1)
    with pytest.raises(InvalidDataError):
        model.fit(frame, ['discrete'])
147+
148+
135149
def test_synthesizer_sample():
136150
"""Test the CTGAN samples the correct datatype."""
137151
data = pd.DataFrame({'discrete': np.random.choice(['a', 'b', 'c'], 100)})

tests/unit/synthesizer/test_ctgan.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
from unittest import TestCase
44
from unittest.mock import Mock
55

6+
import numpy as np
67
import pandas as pd
78
import pytest
89
import torch
910

1011
from ctgan.data_transformer import SpanInfo
12+
from ctgan.errors import InvalidDataError
1113
from ctgan.synthesizers.ctgan import CTGAN, Discriminator, Generator, Residual
1214

1315

@@ -289,3 +291,61 @@ def test__validate_discrete_columns(self):
289291
ctgan = CTGAN(epochs=1)
290292
with pytest.raises(ValueError, match=r'Invalid columns found: {\'doesnt exist\'}'):
291293
ctgan.fit(data, discrete_columns)
294+
295+
def test__validate_null_data(self):
    """Test `_validate_null_data` with pandas and numpy data.

    Check the appropriate error is raised if null values are present in
    continuous columns, both for numpy arrays and dataframes.

    Setup:
        - Create dataframe and numpy array with a discrete column
        - Create dataframe and numpy array with a continuous column without nulls
        - Create dataframe and numpy array with a continuous column with a null

    Input:
        - train_data = 2-dimensional numpy array or a pandas.DataFrame
        - discrete_columns = list of strings or integers

    Output:
        None

    Side Effects:
        - Raises error if a continuous column contains a null value.

    Note:
        - could create another function for numpy array
    """
    # Imported locally so this test does not depend on the module's import block.
    import re

    # Setup
    discrete_df = pd.DataFrame({'discrete': ['a', 'b']})
    discrete_array = np.array([['a'], ['b']])
    continuous_no_nulls_df = pd.DataFrame({'continuous': [0, 1]})
    continuous_no_nulls_array = np.array([[0], [1]])
    continuous_with_null_df = pd.DataFrame({'continuous': [1, np.nan]})
    continuous_with_null_array = np.array([[1], [np.nan]])
    ctgan = CTGAN(epochs=1)
    error_message = (
        'CTGAN does not support null values in the continuous training data. '
        'Please remove all null values from your continuous training data.'
    )
    # ``match`` is applied as a regular expression; escape the message so its
    # periods are matched literally rather than as wildcards.
    error_pattern = re.escape(error_message)

    # Test discrete DataFrame fits without error
    ctgan.fit(discrete_df, ['discrete'])

    # Test discrete array fits without error
    ctgan.fit(discrete_array, [0])

    # Test continuous DataFrame without nulls fits without error
    ctgan.fit(continuous_no_nulls_df)

    # Test continuous array without nulls fits without error
    ctgan.fit(continuous_no_nulls_array)

    # Test nulls in continuous columns DataFrame errors on fit
    with pytest.raises(InvalidDataError, match=error_pattern):
        ctgan.fit(continuous_with_null_df)

    # Test nulls in continuous columns array errors on fit
    with pytest.raises(InvalidDataError, match=error_pattern):
        ctgan.fit(continuous_with_null_array)

0 commit comments

Comments
 (0)