Skip to content

Commit 50ac89b

Browse files
authored
Add REalTabFormer to supported synthesizers (#360)
1 parent fcc853a commit 50ac89b

File tree

8 files changed

+166
-5
lines changed

8 files changed

+166
-5
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,6 @@ tmp/
112112
dask-worker-space
113113
scripts/runs
114114
scripts/datasets
115+
116+
# ReaLTabFormer
117+
rtf_checkpoints/

pyproject.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ dependencies = [
2727
'cloudpickle>=2.1.0',
2828
'compress-pickle>=1.2.0',
2929
'humanfriendly>=8.2',
30-
"numpy>=1.21.0;python_version<'3.10'",
30+
"numpy>=1.21.6;python_version<'3.10'",
3131
"numpy>=1.23.3;python_version>='3.10' and python_version<'3.12'",
3232
"numpy>=1.26.0;python_version>='3.12'",
3333
"pandas>=1.4.0;python_version<'3.11'",
@@ -42,10 +42,10 @@ dependencies = [
4242
"scipy>=1.9.2;python_version>='3.10' and python_version<'3.12'",
4343
"scipy>=1.12.0;python_version>='3.12'",
4444
'tabulate>=0.8.3,<0.9',
45-
"torch>=1.9.0;python_version<'3.10'",
45+
"torch>=1.12.1;python_version<'3.10'",
4646
"torch>=2.0.0;python_version>='3.10' and python_version<'3.12'",
4747
"torch>=2.2.0;python_version>='3.12'",
48-
'tqdm>=4.29',
48+
'tqdm>=4.66.3',
4949
'XlsxWriter>=1.2.8',
5050
'rdt>=1.13.1',
5151
'sdmetrics>=0.17.0',
@@ -64,7 +64,9 @@ sdgym = { main = 'sdgym.cli.__main__:main' }
6464

6565
[project.optional-dependencies]
6666
dask = ['dask', 'distributed']
67+
realtabformer = ['realtabformer>=0.2.1', 'transformers<4.46']
6768
test = [
69+
'sdgym[realtabformer]',
6870
'pytest>=6.2.5',
6971
'pytest-cov>=2.6.0',
7072
'jupyter>=1.0.0,<2',
@@ -231,4 +233,4 @@ convention = "google"
231233

232234
[tool.ruff.lint.pycodestyle]
233235
max-doc-length = 100
234-
max-line-length = 100
236+
max-line-length = 100

sdgym/benchmark.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,7 @@ def benchmark_single_table(
746746
- ``CTGANSynthesizer``
747747
- ``CopulaGANSynthesizer``
748748
- ``TVAESynthesizer``
749+
- ``RealTabFormerSynthesizer``
749750
750751
custom_synthesizers (list[class] or ``None``):
751752
A list of custom synthesizer classes to use. These can be completely custom or

sdgym/synthesizers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
)
1010
from sdgym.synthesizers.identity import DataIdentity
1111
from sdgym.synthesizers.column import ColumnSynthesizer
12+
from sdgym.synthesizers.realtabformer import RealTabFormerSynthesizer
1213
from sdgym.synthesizers.sdv import (
1314
CopulaGANSynthesizer,
1415
CTGANSynthesizer,
@@ -38,4 +39,5 @@
3839
'create_sdv_synthesizer_variant',
3940
'create_sequential_synthesizer',
4041
'SYNTHESIZER_MAPPING',
42+
'RealTabFormerSynthesizer',
4143
)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""REaLTabFormer integration."""
2+
3+
import contextlib
4+
import logging
5+
from functools import partialmethod
6+
7+
import tqdm
8+
9+
from sdgym.synthesizers.base import BaselineSynthesizer
10+
11+
12+
@contextlib.contextmanager
13+
def prevent_tqdm_output():
14+
"""Temporarily disables tqdm m."""
15+
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
16+
try:
17+
yield
18+
finally:
19+
tqdm.__init__ = partialmethod(tqdm.__init__, disable=False)
20+
21+
22+
class RealTabFormerSynthesizer(BaselineSynthesizer):
23+
"""Custom wrapper for the REaLTabFormer synthesizer to make it work with SDGym."""
24+
25+
LOGGER = logging.getLogger(__name__)
26+
27+
def _get_trained_synthesizer(self, data, metadata):
28+
try:
29+
from realtabformer import REaLTabFormer
30+
except Exception as exception:
31+
raise ValueError(
32+
"In order to use 'RealTabFormerSynthesizer' you have to install the extra"
33+
" dependencies by running pip install sdgym['realtabformer'] "
34+
) from exception
35+
36+
with prevent_tqdm_output():
37+
model = REaLTabFormer(model_type='tabular')
38+
model.fit(data, device='cpu')
39+
40+
return model
41+
42+
def _sample_from_synthesizer(self, synthesizer, n_sample):
43+
"""Sample synthetic data with specified sample count."""
44+
return synthesizer.sample(n_sample, device='cpu')
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import sys
2+
3+
import pytest
4+
5+
from sdgym import load_dataset
6+
from sdgym.synthesizers import RealTabFormerSynthesizer
7+
8+
9+
@pytest.mark.skipif(sys.platform.startswith('darwin'), reason='Test not supported on github MacOS')
10+
def test_realtabformer_end_to_end():
11+
"""Test it without metrics."""
12+
# Setup
13+
data, metadata_dict = load_dataset(
14+
'single_table', 'student_placements', limit_dataset_size=False
15+
)
16+
realtabformer_instance = RealTabFormerSynthesizer()
17+
18+
# Run
19+
trained_synthesizer = realtabformer_instance.get_trained_synthesizer(data, metadata_dict)
20+
sampled_data = realtabformer_instance.sample_from_synthesizer(trained_synthesizer, n_samples=10)
21+
22+
# Assert
23+
assert sampled_data.shape[1] == data.shape[1], (
24+
f'Sampled data shape {sampled_data.shape} does not match original data shape {data.shape}'
25+
)
26+
27+
assert set(sampled_data.columns) == set(data.columns)

tests/integration/test_benchmark.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import contextlib
44
import io
55
import re
6+
import sys
67
import time
78

89
import numpy as np
@@ -49,6 +50,25 @@ def test_benchmark_single_table_basic_synthsizers():
4950
] == quality_scores.index.tolist()
5051

5152

53+
@pytest.mark.skipif(sys.platform.startswith('darwin'), reason='Test not supported on github MacOS')
54+
def test_benchmark_single_table_realtabformer_no_metrics():
55+
"""Test it without metrics."""
56+
# Run
57+
output = sdgym.benchmark_single_table(
58+
synthesizers=['RealTabFormerSynthesizer'],
59+
sdv_datasets=['student_placements'],
60+
sdmetrics=[],
61+
)
62+
63+
# Assert
64+
train_time = output['Train_Time'][0]
65+
sample_time = output['Sample_Time'][0]
66+
assert isinstance(train_time, (int, float, complex)), 'Train_Time is not numerical'
67+
assert isinstance(sample_time, (int, float, complex)), 'Sample_Time is not numerical'
68+
assert train_time >= 0
69+
assert sample_time >= 0
70+
71+
5272
def test_benchmark_single_table_no_metrics():
5373
"""Test it without metrics."""
5474
# Run
@@ -62,7 +82,6 @@ def test_benchmark_single_table_no_metrics():
6282
assert not output.empty
6383
assert 'Train_Time' in output
6484
assert 'Sample_Time' in output
65-
6685
# Expect no metric columns.
6786
assert len(output.columns) == 10
6887

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Tests for the realtabformer module."""
2+
3+
from unittest.mock import MagicMock, patch
4+
5+
import numpy as np
6+
import pandas as pd
7+
import pytest
8+
9+
from sdgym.synthesizers import RealTabFormerSynthesizer
10+
11+
12+
@pytest.fixture
13+
def sample_data():
14+
"""Provide sample data for testing."""
15+
n_samples = 10
16+
num_values = np.random.normal(size=n_samples)
17+
18+
return pd.DataFrame({
19+
'num': num_values,
20+
})
21+
22+
23+
class TestRealTabFormerSynthesizer:
24+
"""Unit tests for RealTabFormerSynthesizer integration with SDGym."""
25+
26+
@patch('realtabformer.REaLTabFormer')
27+
def test__get_trained_synthesizer(self, mock_real_tab_former):
28+
"""Test _get_trained_synthesizer
29+
30+
Initializes REaLTabFormer and fits REaLTabFormer with
31+
correct parameters.
32+
"""
33+
# Setup
34+
mock_model = MagicMock()
35+
mock_real_tab_former.return_value = mock_model
36+
data = MagicMock()
37+
metadata = MagicMock()
38+
synthesizer = RealTabFormerSynthesizer()
39+
40+
# Run
41+
result = synthesizer._get_trained_synthesizer(data, metadata)
42+
43+
# Assert
44+
mock_real_tab_former.assert_called_once_with(model_type='tabular')
45+
mock_model.fit.assert_called_once_with(data, device='cpu')
46+
assert result == mock_model, 'Expected the trained model to be returned.'
47+
48+
def test__sample_from_synthesizer(self):
49+
"""Test _sample_from_synthesizer generates data with the specified sample size."""
50+
# Setup
51+
trained_model = MagicMock()
52+
trained_model.sample.return_value = MagicMock(shape=(10, 5)) # Mock sample data shape
53+
n_sample = 10
54+
synthesizer = RealTabFormerSynthesizer()
55+
56+
# Run
57+
synthetic_data = synthesizer._sample_from_synthesizer(trained_model, n_sample)
58+
59+
# Assert
60+
trained_model.sample.assert_called_once_with(n_sample, device='cpu')
61+
assert synthetic_data.shape[0] == n_sample, (
62+
f'Expected {n_sample} rows, but got {synthetic_data.shape[0]}'
63+
)

0 commit comments

Comments
 (0)