Skip to content

Commit aef4e07

Browse files
committed
make release-tag: Merge branch 'main' into stable
2 parents dbe8cd2 + e228a54 commit aef4e07

File tree

6 files changed

+29
-19
lines changed

6 files changed

+29
-19
lines changed

HISTORY.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# History
22

3+
## v0.9.0 - 2024-02-13
4+
5+
This release makes CTGAN sampling more efficient by saving the frequency of each categorical value.
6+
7+
### New Features
8+
9+
* Improve DataSampler efficiency - Issue [#327](https://github.com/sdv-dev/CTGAN/issues/327) by @fealho
10+
311
## v0.8.0 - 2023-11-13
412

513
This release adds a progress bar that will show when setting the `verbose` parameter to `True`

ctgan/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
__author__ = 'DataCebo, Inc.'
66
__email__ = 'info@sdv.dev'
7-
__version__ = '0.8.0'
7+
__version__ = '0.9.0.dev1'
88

99
from ctgan.demo import load_demo
1010
from ctgan.synthesizers.ctgan import CTGAN

ctgan/data_sampler.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ class DataSampler(object):
77
"""DataSampler samples the conditional vector and corresponding data for CTGAN."""
88

99
def __init__(self, data, output_info, log_frequency):
10-
self._data = data
10+
self._data_length = len(data)
1111

1212
def is_discrete_column(column_info):
1313
return (len(column_info) == 1
@@ -115,33 +115,34 @@ def sample_original_condvec(self, batch):
115115
if self._n_discrete_columns == 0:
116116
return None
117117

118+
category_freq = self._discrete_column_category_prob.flatten()
119+
category_freq = category_freq[category_freq != 0]
120+
category_freq = category_freq / np.sum(category_freq)
121+
col_idxs = np.random.choice(np.arange(len(category_freq)), batch, p=category_freq)
118122
cond = np.zeros((batch, self._n_categories), dtype='float32')
119-
120-
for i in range(batch):
121-
row_idx = np.random.randint(0, len(self._data))
122-
col_idx = np.random.randint(0, self._n_discrete_columns)
123-
matrix_st = self._discrete_column_matrix_st[col_idx]
124-
matrix_ed = matrix_st + self._discrete_column_n_category[col_idx]
125-
pick = np.argmax(self._data[row_idx, matrix_st:matrix_ed])
126-
cond[i, pick + self._discrete_column_cond_st[col_idx]] = 1
123+
cond[np.arange(batch), col_idxs] = 1
127124

128125
return cond
129126

130-
def sample_data(self, n, col, opt):
127+
def sample_data(self, data, n, col, opt):
131128
"""Sample data from original training data satisfying the sampled conditional vector.
132129
130+
Args:
131+
data:
132+
The training data.
133133
Returns:
134-
n rows of matrix data.
134+
n:
135+
n rows of matrix data.
135136
"""
136137
if col is None:
137-
idx = np.random.randint(len(self._data), size=n)
138-
return self._data[idx]
138+
idx = np.random.randint(len(data), size=n)
139+
return data[idx]
139140

140141
idx = []
141142
for c, o in zip(col, opt):
142143
idx.append(np.random.choice(self._rid_by_cat_cols[c][o]))
143144

144-
return self._data[idx]
145+
return data[idx]
145146

146147
def dim_cond_vec(self):
147148
"""Return the total number of categories."""

ctgan/synthesizers/ctgan.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,8 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
355355
condvec = self._data_sampler.sample_condvec(self._batch_size)
356356
if condvec is None:
357357
c1, m1, col, opt = None, None, None, None
358-
real = self._data_sampler.sample_data(self._batch_size, col, opt)
358+
real = self._data_sampler.sample_data(
359+
train_data, self._batch_size, col, opt)
359360
else:
360361
c1, m1, col, opt = condvec
361362
c1 = torch.from_numpy(c1).to(self._device)
@@ -365,7 +366,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
365366
perm = np.arange(self._batch_size)
366367
np.random.shuffle(perm)
367368
real = self._data_sampler.sample_data(
368-
self._batch_size, col[perm], opt[perm])
369+
train_data, self._batch_size, col[perm], opt[perm])
369370
c2 = c1[perm]
370371

371372
fake = self._generator(fakez)

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.8.0
2+
current_version = 0.9.0.dev1
33
commit = True
44
tag = True
55
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,6 @@
119119
test_suite='tests',
120120
tests_require=tests_require,
121121
url='https://github.com/sdv-dev/CTGAN',
122-
version='0.8.0',
122+
version='0.9.0.dev1',
123123
zip_safe=False,
124124
)

0 commit comments

Comments
 (0)