Skip to content

Commit a5e2e82

Browse files
committed
make release-tag: Merge branch 'master' into stable
2 parents f71b5cd + 3d42741 commit a5e2e82

File tree

10 files changed

+100
-18
lines changed

10 files changed

+100
-18
lines changed

HISTORY.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
# History
22

3+
## v0.3.1 - 2021-01-27
4+
5+
### Improvements
6+
7+
* Check discrete_columns valid before fitting - [Issue #35](https://github.com/sdv-dev/CTGAN/issues/35) by @fealho
8+
9+
### Bugs fixed
10+
11+
* ValueError: max() arg is an empty sequence - [Issue #115](https://github.com/sdv-dev/CTGAN/issues/115) by @fealho
12+
313
## v0.3.0 - 2020-12-18
414

515
In this release we add a new TVAE model which was presented in the original CTGAN paper.

conda/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{% set name = 'ctgan' %}
2-
{% set version = '0.3.0' %}
2+
{% set version = '0.3.1.dev3' %}
33

44
package:
55
name: "{{ name|lower }}"

ctgan/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
__author__ = 'MIT Data To AI Lab'
66
__email__ = 'dailabmit@gmail.com'
7-
__version__ = '0.3.0'
7+
__version__ = '0.3.1.dev3'
88

99
from ctgan.demo import load_demo
1010
from ctgan.synthesizers.ctgan import CTGANSynthesizer

ctgan/__main__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,11 @@ def main():
6666
if args.load:
6767
model = CTGANSynthesizer.load(args.load)
6868
else:
69-
generator_dims = [int(x) for x in args.generator_dims.split(',')]
70-
discriminator_dims = [int(x) for x in args.discriminator_dims.split(',')]
69+
generator_dim = [int(x) for x in args.generator_dim.split(',')]
70+
discriminator_dim = [int(x) for x in args.discriminator_dim.split(',')]
7171
model = CTGANSynthesizer(
72-
embedding_dim=args.embedding_dim, generator_dims=generator_dims,
73-
discriminator_dims=discriminator_dims, generator_lr=args.generator_lr,
72+
embedding_dim=args.embedding_dim, generator_dim=generator_dim,
73+
discriminator_dim=discriminator_dim, generator_lr=args.generator_lr,
7474
generator_decay=args.generator_decay, discriminator_lr=args.discriminator_lr,
7575
discriminator_decay=args.discriminator_decay, batch_size=args.batch_size,
7676
epochs=args.epochs)

ctgan/data_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def is_discrete_column(column_info):
4141
# Prepare an interval matrix for efficiently sampling the conditional vector
4242
max_category = max(
4343
[column_info[0].dim for column_info in output_info
44-
if is_discrete_column(column_info)])
44+
if is_discrete_column(column_info)], default=0)
4545

4646
self._discrete_column_cond_st = np.zeros(n_discrete_columns, dtype='int32')
4747
self._discrete_column_n_category = np.zeros(
@@ -133,7 +133,7 @@ def sample_data(self, n, col, opt):
133133
n rows of matrix data.
134134
"""
135135
if col is None:
136-
idx = np.random.randint(len(self._data), n)
136+
idx = np.random.randint(len(self._data), size=n)
137137
return self._data[idx]
138138

139139
idx = []

ctgan/synthesizers/ctgan.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import warnings
22

33
import numpy as np
4+
import pandas as pd
45
import torch
56
from packaging import version
67
from torch import optim
@@ -13,13 +14,13 @@
1314

1415
class Discriminator(Module):
1516

16-
def __init__(self, input_dim, dis_dims, pack=10):
17+
def __init__(self, input_dim, discriminator_dim, pack=10):
1718
super(Discriminator, self).__init__()
1819
dim = input_dim * pack
1920
self.pack = pack
2021
self.packdim = dim
2122
seq = []
22-
for item in list(dis_dims):
23+
for item in list(discriminator_dim):
2324
seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)]
2425
dim = item
2526

@@ -222,6 +223,31 @@ def _cond_loss(self, data, c, m):
222223

223224
return (loss * m).sum() / data.size()[0]
224225

226+
def _validate_discrete_columns(self, train_data, discrete_columns):
227+
"""Check whether all ``discrete_columns`` exist in ``train_data``.
228+
229+
Args:
230+
train_data (numpy.ndarray or pandas.DataFrame):
231+
Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
232+
discrete_columns (list-like):
233+
List of discrete columns to be used to generate the Conditional
234+
Vector. If ``train_data`` is a Numpy array, this list should
235+
contain the integer indices of the columns. Otherwise, if it is
236+
a ``pandas.DataFrame``, this list should contain the column names.
237+
"""
238+
if isinstance(train_data, pd.DataFrame):
239+
invalid_columns = set(discrete_columns) - set(train_data.columns)
240+
elif isinstance(train_data, np.ndarray):
241+
invalid_columns = []
242+
for column in discrete_columns:
243+
if column < 0 or column >= train_data.shape[1]:
244+
invalid_columns.append(column)
245+
else:
246+
raise TypeError('``train_data`` should be either pd.DataFrame or np.array.')
247+
248+
if invalid_columns:
249+
raise ValueError('Invalid columns found: {}'.format(invalid_columns))
250+
225251
def fit(self, train_data, discrete_columns=tuple(), epochs=None):
226252
"""Fit the CTGAN Synthesizer models to the training data.
227253
@@ -234,6 +260,8 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
234260
contain the integer indices of the columns. Otherwise, if it is
235261
a ``pandas.DataFrame``, this list should contain the column names.
236262
"""
263+
self._validate_discrete_columns(train_data, discrete_columns)
264+
237265
if epochs is None:
238266
epochs = self._epochs
239267
else:

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.3.0
2+
current_version = 0.3.1.dev3
33
commit = True
44
tag = True
55
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
'torch<2,>=1.0',
1616
'torchvision<1,>=0.4.2',
1717
'scikit-learn<0.24,>=0.21',
18-
'rdt>=0.2.7,<0.3',
1918
'numpy<2,>=1.17.4',
2019
'pandas<1.1.5,>=0.24',
20+
'rdt>=0.2.7,<0.4',
2121
'packaging',
2222
]
2323

@@ -99,6 +99,6 @@
9999
test_suite='tests',
100100
tests_require=tests_require,
101101
url='https://github.com/sdv-dev/CTGAN',
102-
version='0.3.0',
102+
version='0.3.1.dev3',
103103
zip_safe=False,
104104
)

tests/integration/test_ctgan.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,30 @@
99
model are not checked.
1010
"""
1111

12+
import tempfile as tf
13+
1214
import numpy as np
1315
import pandas as pd
16+
import pytest
1417

1518
from ctgan.synthesizers.ctgan import CTGANSynthesizer
1619

1720

21+
def test_ctgan_no_categoricals():
22+
data = pd.DataFrame({
23+
'continuous': np.random.random(1000)
24+
})
25+
26+
ctgan = CTGANSynthesizer(epochs=1)
27+
ctgan.fit(data, [])
28+
29+
sampled = ctgan.sample(100)
30+
31+
assert sampled.shape == (100, 1)
32+
assert isinstance(sampled, pd.DataFrame)
33+
assert set(sampled.columns) == {'continuous'}
34+
35+
1836
def test_ctgan_dataframe():
1937
data = pd.DataFrame({
2038
'continuous': np.random.random(100),
@@ -120,10 +138,33 @@ def test_save_load():
120138

121139
ctgan = CTGANSynthesizer(epochs=1)
122140
ctgan.fit(data, discrete_columns)
123-
ctgan.save("test_ctgan.pkl")
124141

125-
ctgan = CTGANSynthesizer.load("test_ctgan.pkl")
142+
with tf.TemporaryDirectory() as temporary_directory:
143+
ctgan.save(temporary_directory + "test_tvae.pkl")
144+
ctgan = CTGANSynthesizer.load(temporary_directory + "test_tvae.pkl")
126145

127146
sampled = ctgan.sample(1000)
128147
assert set(sampled.columns) == {'continuous', 'discrete'}
129148
assert set(sampled['discrete'].unique()) == {'a', 'b', 'c'}
149+
150+
151+
def test_wrong_discrete_columns_dataframe():
152+
data = pd.DataFrame({
153+
'discrete': ['a', 'b']
154+
})
155+
discrete_columns = ['b', 'c']
156+
157+
ctgan = CTGANSynthesizer(epochs=1)
158+
with pytest.raises(ValueError):
159+
ctgan.fit(data, discrete_columns)
160+
161+
162+
def test_wrong_discrete_columns_numpy():
163+
data = pd.DataFrame({
164+
'discrete': ['a', 'b']
165+
})
166+
discrete_columns = [0, 1]
167+
168+
ctgan = CTGANSynthesizer(epochs=1)
169+
with pytest.raises(ValueError):
170+
ctgan.fit(data.to_numpy(), discrete_columns)

tests/integration/test_tvae.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
model are not checked.
1010
"""
1111

12+
import tempfile as tf
13+
1214
import numpy as np
1315
import pandas as pd
1416

@@ -70,11 +72,12 @@ def test_save_load():
7072
})
7173
discrete_columns = ['discrete']
7274

73-
tvae = TVAESynthesizer(epochs=1)
75+
tvae = TVAESynthesizer(epochs=10)
7476
tvae.fit(data, discrete_columns)
75-
tvae.save("test_tvae.pkl")
7677

77-
tvae = TVAESynthesizer.load("test_tvae.pkl")
78+
with tf.TemporaryDirectory() as temporary_directory:
79+
tvae.save(temporary_directory + "test_tvae.pkl")
80+
tvae = TVAESynthesizer.load(temporary_directory + "test_tvae.pkl")
7881

7982
sampled = tvae.sample(1000)
8083
assert set(sampled.columns) == {'continuous', 'discrete'}

0 commit comments

Comments
 (0)