Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions sdgym/synthesizers/realtabformer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""REaLTabFormer integration."""

import contextlib
import inspect
import logging
from functools import partialmethod

Expand Down Expand Up @@ -37,9 +38,17 @@ def _get_trained_synthesizer(self, data, metadata):
with prevent_tqdm_output():
model_kwargs = self._MODEL_KWARGS.copy() if self._MODEL_KWARGS else {}
model = REaLTabFormer(model_type='tabular', **model_kwargs)
model.fit(data)

return model
# REaLTabFormer >= 0.2.3 changed the default behavior of `fit` by introducing the
# `save_full_every_epoch` and `gen_kwargs` parameters. The new defaults break the SDGym
# end-to-end test, so we set them explicitly to preserve the previous behavior.
fit_sig = inspect.signature(model.fit)
if {'save_full_every_epoch', 'gen_kwargs'} <= fit_sig.parameters.keys():
model.fit(data, save_full_every_epoch=0, gen_kwargs={})
else:
model.fit(data)

return model

def _sample_from_synthesizer(self, synthesizer, n_sample):
"""Sample synthetic data with specified sample count."""
Expand Down
29 changes: 28 additions & 1 deletion tests/unit/synthesizers/test_realtabformer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for the realtabformer module."""

from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock, patch

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -45,6 +45,33 @@ def test__get_trained_synthesizer(self, mock_real_tab_former):
mock_model.fit.assert_called_once_with(data)
assert result == mock_model, 'Expected the trained model to be returned.'

@patch('realtabformer.REaLTabFormer')
@patch('sdgym.synthesizers.realtabformer.inspect')
def test__get_trained_synthesizer_with_fit_parameters(self, mock_inspect, mock_real_tab_former):
    """Test that ``fit`` gets explicit overrides when its signature exposes them.

    The patched ``inspect.signature`` reports both ``save_full_every_epoch`` and
    ``gen_kwargs`` (plus an unrelated parameter), so the synthesizer is expected
    to call ``fit`` with ``save_full_every_epoch=0`` and ``gen_kwargs={}``.
    """
    # Setup
    reported_parameters = dict.fromkeys(
        ('save_full_every_epoch', 'gen_kwargs', 'other_param')
    )
    mock_inspect.signature.return_value.parameters = reported_parameters
    trained_model = Mock()
    mock_real_tab_former.return_value = trained_model

    instance = RealTabFormerSynthesizer()
    instance._MODEL_KWARGS = {'epochs': 5}
    data, metadata = Mock(), Mock()

    # Run
    returned = instance._get_trained_synthesizer(data, metadata)

    # Assert
    assert returned is trained_model
    mock_inspect.signature.assert_called_once_with(trained_model.fit)
    mock_real_tab_former.assert_called_once_with(model_type='tabular', epochs=5)
    trained_model.fit.assert_called_once_with(data, save_full_every_epoch=0, gen_kwargs={})

def test__sample_from_synthesizer(self):
"""Test _sample_from_synthesizer generates data with the specified sample size."""
# Setup
Expand Down