diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index e0663cd2b..8dcd190e3 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -11,9 +11,10 @@ import tqdm from rdt.transformers import FloatFormatter -from sdv._utils import MODELABLE_SDTYPES, _cast_to_iterable, _groupby_list +from sdv._utils import MODELABLE_SDTYPES, _cast_to_iterable, _groupby_list, _is_datetime_type from sdv.cag import ProgrammableConstraint from sdv.cag._utils import _validate_constraints_single_table +from sdv.constraints.utils import cast_to_datetime64 from sdv.errors import SamplingError, SynthesizerInputError from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.metadata import Metadata @@ -37,6 +38,11 @@ LOGGER = logging.getLogger(__name__) +def _diff_and_bfill(series): + """Compute the diff of a pandas Series and backfill the first NaN.""" + return series.diff().bfill() + + class PARSynthesizer(LossValuesMixin, MissingModuleMixin, BaseSynthesizer): """Synthesizer for sequential data. @@ -310,20 +316,25 @@ def _transform_sequence_index(self, data): sequence_index_context = sequence_index_context.rename( columns={self._sequence_index: f'{self._sequence_index}.context'} ) + + if _is_datetime_type(sequence_index[self._sequence_index]): + sequence_index[self._sequence_index] = cast_to_datetime64( + sequence_index[self._sequence_index] + ).astype(np.int64) + if all(sequence_index[self._sequence_key].nunique() == 1): - sequence_index_sequence = sequence_index[[self._sequence_index]].diff().bfill() + diff_series = sequence_index[self._sequence_index].diff().bfill() else: - sequence_index_sequence = ( - sequence_index.groupby(self._sequence_key) - .apply(lambda x: x[self._sequence_index].diff().bfill()) - .droplevel(1) - .reset_index() - ) + diff_series = sequence_index.groupby(self._sequence_key, group_keys=False)[ + self._sequence_index + ].transform(_diff_and_bfill) + sequence_index_sequence = diff_series.to_frame(name=self._sequence_index) if all(sequence_index_sequence[self._sequence_index].isna()): fill_value = 0 else: fill_value = min(sequence_index_sequence[self._sequence_index].dropna()) + sequence_index_sequence = sequence_index_sequence.fillna(fill_value) data[self._sequence_index] = sequence_index_sequence[self._sequence_index].to_numpy() @@ -573,7 +584,7 @@ def _sample_from_par(self, context, sequence_length=None): pd.DataFrame({self._sequence_index: diffs}) )[self._sequence_index].to_numpy() start_index = context_columns.index(f'{self._sequence_index}.context') - start = context_values[start_index] + start = context_values.iloc[start_index] sequence[sequence_index_idx] = np.cumsum(diffs) - diffs[0] + start # Reformat as a DataFrame diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index c7bb4e332..7e761b0ca 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -14,11 +14,24 @@ from sdv.metadata.metadata import Metadata from sdv.metadata.single_table import SingleTableMetadata from sdv.sampling import Condition -from sdv.sequential.par import PARSynthesizer +from sdv.sequential.par import PARSynthesizer, _diff_and_bfill from sdv.single_table.base import BaseSynthesizer from sdv.single_table.copulas import GaussianCopulaSynthesizer +def test__diff_and_bfill(): + """Test the ``_diff_and_bfill`` method.""" + # Setup + data = pd.Series([10, 15, 20, 30]) + + # Run + result = _diff_and_bfill(data) + + # Assert + expected = pd.Series([5.0, 5.0, 5.0, 10.0]) + pd.testing.assert_series_equal(result, expected) + + class TestPARSynthesizer: def get_metadata(self, add_sequence_key=True, add_sequence_index=False): metadata = Metadata() @@ -283,6 +296,7 @@ def test_validate_context_columns_unique_per_sequence_key(self): with pytest.raises(InvalidDataError, match=err_msg): instance.validate(data) + @pytest.mark.filterwarnings('error::FutureWarning') def test__transform_sequence(self): # Setup metadata = self.get_metadata(add_sequence_index=True) @@ -310,6 +324,7 @@ def test__transform_sequence(self): assert list(par.extended_columns.keys()) == ['time'] assert par.extended_columns['time'].enforce_min_max_values is True + @pytest.mark.filterwarnings('error::FutureWarning') def test__transform_sequence_index_single_instances(self): # Setup metadata = self.get_metadata(add_sequence_index=True) @@ -332,6 +347,7 @@ def test__transform_sequence_index_single_instances(self): assert list(par.extended_columns.keys()) == ['time'] assert par.extended_columns['time'].enforce_min_max_values is True + @pytest.mark.filterwarnings('error::FutureWarning') def test__transform_sequence_index_non_unique_sequence_key(self): # Setup metadata = self.get_metadata(add_sequence_index=True) @@ -833,6 +849,7 @@ def test__sample_from_par_with_sequence_key(self, tqdm_mock): }) pd.testing.assert_frame_equal(sampled, expected_output) + @pytest.mark.filterwarnings('error::FutureWarning') @patch('sdv.sequential.par.tqdm') def test__sample_from_par_with_sequence_index(self, tqdm_mock): """Test that the method handles the sequence index properly. @@ -1245,6 +1262,9 @@ def test_sample_with_all_null_column_categorical(self): assert result['all_null_cat_col'].isna().all() assert len(result) > 0 + @pytest.mark.filterwarnings( + 'error:Series.__getitem__ treating keys as positions is deprecated:FutureWarning' + ) def test_sample_with_multiple_all_null_columns(self): """Test that sampling works correctly with multiple all-null columns.""" # Setup @@ -1257,15 +1277,21 @@ def test_sample_with_multiple_all_null_columns(self): 'all_null_col2': [np.nan] * 9, }) - metadata = Metadata() - metadata.add_table('table') - metadata.add_column('time', 'table', sdtype='datetime') - metadata.add_column('gender', 'table', sdtype='categorical') - metadata.add_column('name', 'table', sdtype='id') - metadata.add_column('measurement', 'table', sdtype='numerical') - metadata.add_column('all_null_col1', 'table', sdtype='numerical') - metadata.add_column('all_null_col2', 'table', sdtype='categorical') - metadata.set_sequence_key('name', 'table') + metadata = Metadata().load_from_dict({ + 'tables': { + 'table': { + 'columns': { + 'time': {'sdtype': 'datetime'}, + 'gender': {'sdtype': 'categorical'}, + 'name': {'sdtype': 'id'}, + 'measurement': {'sdtype': 'numerical'}, + 'all_null_col1': {'sdtype': 'numerical'}, + 'all_null_col2': {'sdtype': 'categorical'}, + }, + 'sequence_key': 'name', + } + } + }) # Run synthesizer = PARSynthesizer(metadata=metadata, epochs=1)