Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
import tqdm
from rdt.transformers import FloatFormatter

from sdv._utils import MODELABLE_SDTYPES, _cast_to_iterable, _groupby_list
from sdv._utils import MODELABLE_SDTYPES, _cast_to_iterable, _groupby_list, _is_datetime_type
from sdv.cag import ProgrammableConstraint
from sdv.cag._utils import _validate_constraints_single_table
from sdv.constraints.utils import cast_to_datetime64
from sdv.errors import SamplingError, SynthesizerInputError
from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.metadata import Metadata
Expand All @@ -37,6 +38,11 @@
LOGGER = logging.getLogger(__name__)


def _diff_and_bfill(series):
"""Compute the diff of a pandas Series and backfill the first NaN."""
return series.diff().bfill()


class PARSynthesizer(LossValuesMixin, MissingModuleMixin, BaseSynthesizer):
"""Synthesizer for sequential data.

Expand Down Expand Up @@ -310,20 +316,25 @@ def _transform_sequence_index(self, data):
sequence_index_context = sequence_index_context.rename(
columns={self._sequence_index: f'{self._sequence_index}.context'}
)

if _is_datetime_type(sequence_index[self._sequence_index]):
sequence_index[self._sequence_index] = cast_to_datetime64(
sequence_index[self._sequence_index]
).astype(np.int64)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before pandas 2.2, datetime columns stored with object dtype were silently converted to float dtype. To stay consistent with the numerical fill_value computed on L330/L332, this column has to be numerical as well.

Let me know if this makes sense


if all(sequence_index[self._sequence_key].nunique() == 1):
sequence_index_sequence = sequence_index[[self._sequence_index]].diff().bfill()
diff_series = sequence_index[self._sequence_index].diff().bfill()
else:
sequence_index_sequence = (
sequence_index.groupby(self._sequence_key)
.apply(lambda x: x[self._sequence_index].diff().bfill())
.droplevel(1)
.reset_index()
)
diff_series = sequence_index.groupby(self._sequence_key, group_keys=False)[
self._sequence_index
].transform(_diff_and_bfill)

sequence_index_sequence = diff_series.to_frame(name=self._sequence_index)
if all(sequence_index_sequence[self._sequence_index].isna()):
fill_value = 0
else:
fill_value = min(sequence_index_sequence[self._sequence_index].dropna())

sequence_index_sequence = sequence_index_sequence.fillna(fill_value)

data[self._sequence_index] = sequence_index_sequence[self._sequence_index].to_numpy()
Expand Down Expand Up @@ -573,7 +584,7 @@ def _sample_from_par(self, context, sequence_length=None):
pd.DataFrame({self._sequence_index: diffs})
)[self._sequence_index].to_numpy()
start_index = context_columns.index(f'{self._sequence_index}.context')
start = context_values[start_index]
start = context_values.iloc[start_index]
sequence[sequence_index_idx] = np.cumsum(diffs) - diffs[0] + start

# Reformat as a DataFrame
Expand Down
46 changes: 36 additions & 10 deletions tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,24 @@
from sdv.metadata.metadata import Metadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.sampling import Condition
from sdv.sequential.par import PARSynthesizer
from sdv.sequential.par import PARSynthesizer, _diff_and_bfill
from sdv.single_table.base import BaseSynthesizer
from sdv.single_table.copulas import GaussianCopulaSynthesizer


def test__diff_and_bfill():
    """Test the ``_diff_and_bfill`` method."""
    # Setup
    series = pd.Series([10, 15, 20, 30])

    # Run
    output = _diff_and_bfill(series)

    # Assert
    pd.testing.assert_series_equal(output, pd.Series([5.0, 5.0, 5.0, 10.0]))


class TestPARSynthesizer:
def get_metadata(self, add_sequence_key=True, add_sequence_index=False):
metadata = Metadata()
Expand Down Expand Up @@ -283,6 +296,7 @@ def test_validate_context_columns_unique_per_sequence_key(self):
with pytest.raises(InvalidDataError, match=err_msg):
instance.validate(data)

@pytest.mark.filterwarnings('error::FutureWarning')
def test__transform_sequence(self):
# Setup
metadata = self.get_metadata(add_sequence_index=True)
Expand Down Expand Up @@ -310,6 +324,7 @@ def test__transform_sequence(self):
assert list(par.extended_columns.keys()) == ['time']
assert par.extended_columns['time'].enforce_min_max_values is True

@pytest.mark.filterwarnings('error::FutureWarning')
def test__transform_sequence_index_single_instances(self):
# Setup
metadata = self.get_metadata(add_sequence_index=True)
Expand All @@ -332,6 +347,7 @@ def test__transform_sequence_index_single_instances(self):
assert list(par.extended_columns.keys()) == ['time']
assert par.extended_columns['time'].enforce_min_max_values is True

@pytest.mark.filterwarnings('error::FutureWarning')
def test__transform_sequence_index_non_unique_sequence_key(self):
# Setup
metadata = self.get_metadata(add_sequence_index=True)
Expand Down Expand Up @@ -833,6 +849,7 @@ def test__sample_from_par_with_sequence_key(self, tqdm_mock):
})
pd.testing.assert_frame_equal(sampled, expected_output)

@pytest.mark.filterwarnings('error::FutureWarning')
@patch('sdv.sequential.par.tqdm')
def test__sample_from_par_with_sequence_index(self, tqdm_mock):
"""Test that the method handles the sequence index properly.
Expand Down Expand Up @@ -1245,6 +1262,9 @@ def test_sample_with_all_null_column_categorical(self):
assert result['all_null_cat_col'].isna().all()
assert len(result) > 0

@pytest.mark.filterwarnings(
'error:Series.__getitem__ treating keys as positions is deprecated:FutureWarning'
)
def test_sample_with_multiple_all_null_columns(self):
"""Test that sampling works correctly with multiple all-null columns."""
# Setup
Expand All @@ -1257,15 +1277,21 @@ def test_sample_with_multiple_all_null_columns(self):
'all_null_col2': [np.nan] * 9,
})

metadata = Metadata()
metadata.add_table('table')
metadata.add_column('time', 'table', sdtype='datetime')
metadata.add_column('gender', 'table', sdtype='categorical')
metadata.add_column('name', 'table', sdtype='id')
metadata.add_column('measurement', 'table', sdtype='numerical')
metadata.add_column('all_null_col1', 'table', sdtype='numerical')
metadata.add_column('all_null_col2', 'table', sdtype='categorical')
metadata.set_sequence_key('name', 'table')
metadata = Metadata().load_from_dict({
'tables': {
'table': {
'columns': {
'time': {'sdtype': 'datetime'},
'gender': {'sdtype': 'categorical'},
'name': {'sdtype': 'id'},
'measurement': {'sdtype': 'numerical'},
'all_null_col1': {'sdtype': 'numerical'},
'all_null_col2': {'sdtype': 'categorical'},
},
'sequence_key': 'name',
}
}
})

# Run
synthesizer = PARSynthesizer(metadata=metadata, epochs=1)
Expand Down
Loading