Skip to content

Commit c131303

Browse files
authored
Merge branch 'main' into feature-branch-download-demo
2 parents 18ece9e + 0983c8a commit c131303

File tree

11 files changed

+121
-46
lines changed

11 files changed

+121
-46
lines changed

latest_requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ ctgan==0.11.0
44
deepecho==0.7.0
55
graphviz==0.21
66
numpy==2.3.3
7-
pandas==2.3.2
7+
pandas==2.3.3
88
platformdirs==4.4.0
99
rdt==1.18.1
1010
sdmetrics==0.23.0

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ dependencies = [
3838
'copulas>=0.12.1',
3939
'ctgan>=0.11.0',
4040
'deepecho>=0.7.0',
41-
'rdt>=1.17.0',
41+
'rdt>=1.18.2',
4242
'sdmetrics>=0.21.0',
4343
'platformdirs>=4.0',
4444
'pyyaml>=6.0.1',

sdv/sequential/par.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
import tqdm
1212
from rdt.transformers import FloatFormatter
1313

14-
from sdv._utils import MODELABLE_SDTYPES, _cast_to_iterable, _groupby_list
14+
from sdv._utils import MODELABLE_SDTYPES, _cast_to_iterable, _groupby_list, _is_datetime_type
1515
from sdv.cag import ProgrammableConstraint
1616
from sdv.cag._utils import _validate_constraints_single_table
17+
from sdv.constraints.utils import cast_to_datetime64
1718
from sdv.errors import SamplingError, SynthesizerInputError
1819
from sdv.metadata.errors import InvalidMetadataError
1920
from sdv.metadata.metadata import Metadata
@@ -37,6 +38,11 @@
3738
LOGGER = logging.getLogger(__name__)
3839

3940

41+
def _diff_and_bfill(series):
42+
"""Compute the diff of a pandas Series and backfill the first NaN."""
43+
return series.diff().bfill()
44+
45+
4046
class PARSynthesizer(LossValuesMixin, MissingModuleMixin, BaseSynthesizer):
4147
"""Synthesizer for sequential data.
4248
@@ -310,20 +316,25 @@ def _transform_sequence_index(self, data):
310316
sequence_index_context = sequence_index_context.rename(
311317
columns={self._sequence_index: f'{self._sequence_index}.context'}
312318
)
319+
320+
if _is_datetime_type(sequence_index[self._sequence_index]):
321+
sequence_index[self._sequence_index] = cast_to_datetime64(
322+
sequence_index[self._sequence_index]
323+
).astype(np.int64)
324+
313325
if all(sequence_index[self._sequence_key].nunique() == 1):
314-
sequence_index_sequence = sequence_index[[self._sequence_index]].diff().bfill()
326+
diff_series = sequence_index[self._sequence_index].diff().bfill()
315327
else:
316-
sequence_index_sequence = (
317-
sequence_index.groupby(self._sequence_key)
318-
.apply(lambda x: x[self._sequence_index].diff().bfill())
319-
.droplevel(1)
320-
.reset_index()
321-
)
328+
diff_series = sequence_index.groupby(self._sequence_key, group_keys=False)[
329+
self._sequence_index
330+
].transform(_diff_and_bfill)
322331

332+
sequence_index_sequence = diff_series.to_frame(name=self._sequence_index)
323333
if all(sequence_index_sequence[self._sequence_index].isna()):
324334
fill_value = 0
325335
else:
326336
fill_value = min(sequence_index_sequence[self._sequence_index].dropna())
337+
327338
sequence_index_sequence = sequence_index_sequence.fillna(fill_value)
328339

329340
data[self._sequence_index] = sequence_index_sequence[self._sequence_index].to_numpy()
@@ -573,7 +584,7 @@ def _sample_from_par(self, context, sequence_length=None):
573584
pd.DataFrame({self._sequence_index: diffs})
574585
)[self._sequence_index].to_numpy()
575586
start_index = context_columns.index(f'{self._sequence_index}.context')
576-
start = context_values[start_index]
587+
start = context_values.iloc[start_index]
577588
sequence[sequence_index_idx] = np.cumsum(diffs) - diffs[0] + start
578589

579590
# Reformat as a DataFrame

sdv/single_table/_dayz_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def detect_column_parameters(data, metadata, table_name):
4242
if sdtype == 'numerical':
4343
column_parameters[column_name] = {
4444
'num_decimal_digits': learn_rounding_digits(data[column_name]),
45-
'min_value': data[column_name].min().item(),
46-
'max_value': data[column_name].max().item(),
45+
'min_value': data[column_name].min(),
46+
'max_value': data[column_name].max(),
4747
}
4848
elif sdtype == 'datetime':
4949
datetime_format = column_metadata.get('datetime_format', None)
@@ -63,13 +63,13 @@ def detect_column_parameters(data, metadata, table_name):
6363
'start_timestamp': start_timestamp,
6464
'end_timestamp': end_timestamp,
6565
}
66-
elif sdtype in ['categorical', 'boolean']:
66+
elif sdtype == 'categorical':
6767
column_parameters[column_name] = {
6868
'category_values': data[column_name].dropna().unique().tolist()
6969
}
7070

71-
column_parameters[column_name]['missing_values_proportion'] = (
72-
data[column_name].isna().mean().item()
71+
column_parameters[column_name]['missing_values_proportion'] = float(
72+
data[column_name].isna().mean()
7373
)
7474

7575
return {'columns': column_parameters}

sdv/single_table/dayz.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def _validate_categorical_parameters(column_parameters, column_table_msg):
132132
raise SynthesizerProcessingError(msg)
133133

134134

135-
def _validate_missing_value_parameters(column_parameters, column_table_msg):
135+
def _validate_missing_value_parameters(column_parameters, column_table_msg, is_key_column):
136136
missing_values_proportion = column_parameters['missing_values_proportion']
137137
if not _is_numerical(missing_values_proportion) or (
138138
missing_values_proportion < 0.0 or missing_values_proportion > 1.0
@@ -142,9 +142,15 @@ def _validate_missing_value_parameters(column_parameters, column_table_msg):
142142
'must be a float between 0.0 and 1.0.'
143143
)
144144
raise SynthesizerProcessingError(msg)
145+
elif is_key_column and missing_values_proportion != 0:
146+
msg = (
147+
f"Invalid 'missing_values_proportion' parameter for {column_table_msg}. Primary "
148+
"and alternate keys must have 'missing_values_proportion' parameter set to zero."
149+
)
150+
raise SynthesizerProcessingError(msg)
145151

146152

147-
def _validate_column_parameters(table, column, column_metadata, column_parameters):
153+
def _validate_column_parameters(table, column, column_metadata, column_parameters, is_key_column):
148154
column_table_msg = f"column '{column}' in table '{table}'"
149155
sdtype = column_metadata['sdtype']
150156
sdtype_parameters = SDTYPE_TO_PARAMETERS.get(sdtype, COLUMN_PARAMETER_KEYS)
@@ -165,7 +171,7 @@ def _validate_column_parameters(table, column, column_metadata, column_parameter
165171
_validate_categorical_parameters(column_parameters, column_table_msg)
166172

167173
if 'missing_values_proportion' in column_parameters:
168-
_validate_missing_value_parameters(column_parameters, column_table_msg)
174+
_validate_missing_value_parameters(column_parameters, column_table_msg, is_key_column)
169175

170176

171177
def _validate_table_parameters(table, table_metadata, table_parameters):
@@ -186,9 +192,11 @@ def _validate_table_parameters(table, table_metadata, table_parameters):
186192
)
187193
raise SynthesizerProcessingError(msg)
188194

195+
key_columns = table_metadata._get_primary_and_alternate_keys()
189196
for column, column_parameters in table_parameters.get('columns', {}).items():
197+
is_key_column = column in key_columns
190198
_validate_column_parameters(
191-
table, column, table_metadata.columns[column], column_parameters
199+
table, column, table_metadata.columns[column], column_parameters, is_key_column
192200
)
193201

194202

tasks.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
from packaging.version import Version
1414

1515
COMPARISONS = {'>=': operator.ge, '>': operator.gt, '<': operator.lt, '<=': operator.le}
16-
16+
EXTERNAL_DEPENDENCY_CAPS = {
17+
'torch': '2.9.0'
18+
}
1719

1820
if not hasattr(inspect, 'getargspec'):
1921
inspect.getargspec = inspect.getfullargspec
@@ -86,6 +88,8 @@ def install_minimum(c):
8688
if minimum_versions:
8789
install_deps = ' '.join(minimum_versions)
8890
c.run(f'python -m pip install {install_deps}')
91+
for dep, cap in EXTERNAL_DEPENDENCY_CAPS.items():
92+
c.run(f'python -m pip install "{dep}<{cap}"')
8993

9094

9195
@task

tests/integration/multi_table/test_dayz.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ def test_create_parameters_end_to_end(self):
2424
'guest_email': {'missing_values_proportion': 0.0},
2525
'hotel_id': {'missing_values_proportion': 0.0},
2626
'has_rewards': {
27-
'category_values': [False, True],
2827
'missing_values_proportion': 0.0,
2928
},
3029
'room_type': {

tests/integration/single_table/test_dayz.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def test_create_parameters_end_to_end(self):
2323
'columns': {
2424
'guest_email': {'missing_values_proportion': 0.0},
2525
'has_rewards': {
26-
'category_values': [False, True],
2726
'missing_values_proportion': 0.0,
2827
},
2928
'room_type': {

tests/unit/sequential/test_par.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,24 @@
1414
from sdv.metadata.metadata import Metadata
1515
from sdv.metadata.single_table import SingleTableMetadata
1616
from sdv.sampling import Condition
17-
from sdv.sequential.par import PARSynthesizer
17+
from sdv.sequential.par import PARSynthesizer, _diff_and_bfill
1818
from sdv.single_table.base import BaseSynthesizer
1919
from sdv.single_table.copulas import GaussianCopulaSynthesizer
2020

2121

22+
def test__diff_and_bfill():
23+
"""Test the ``_diff_and_bfill`` method."""
24+
# Setup
25+
data = pd.Series([10, 15, 20, 30])
26+
27+
# Run
28+
result = _diff_and_bfill(data)
29+
30+
# Assert
31+
expected = pd.Series([5.0, 5.0, 5.0, 10.0])
32+
pd.testing.assert_series_equal(result, expected)
33+
34+
2235
class TestPARSynthesizer:
2336
def get_metadata(self, add_sequence_key=True, add_sequence_index=False):
2437
metadata = Metadata()
@@ -283,6 +296,7 @@ def test_validate_context_columns_unique_per_sequence_key(self):
283296
with pytest.raises(InvalidDataError, match=err_msg):
284297
instance.validate(data)
285298

299+
@pytest.mark.filterwarnings('error::FutureWarning')
286300
def test__transform_sequence(self):
287301
# Setup
288302
metadata = self.get_metadata(add_sequence_index=True)
@@ -310,6 +324,7 @@ def test__transform_sequence(self):
310324
assert list(par.extended_columns.keys()) == ['time']
311325
assert par.extended_columns['time'].enforce_min_max_values is True
312326

327+
@pytest.mark.filterwarnings('error::FutureWarning')
313328
def test__transform_sequence_index_single_instances(self):
314329
# Setup
315330
metadata = self.get_metadata(add_sequence_index=True)
@@ -332,6 +347,7 @@ def test__transform_sequence_index_single_instances(self):
332347
assert list(par.extended_columns.keys()) == ['time']
333348
assert par.extended_columns['time'].enforce_min_max_values is True
334349

350+
@pytest.mark.filterwarnings('error::FutureWarning')
335351
def test__transform_sequence_index_non_unique_sequence_key(self):
336352
# Setup
337353
metadata = self.get_metadata(add_sequence_index=True)
@@ -833,6 +849,7 @@ def test__sample_from_par_with_sequence_key(self, tqdm_mock):
833849
})
834850
pd.testing.assert_frame_equal(sampled, expected_output)
835851

852+
@pytest.mark.filterwarnings('error::FutureWarning')
836853
@patch('sdv.sequential.par.tqdm')
837854
def test__sample_from_par_with_sequence_index(self, tqdm_mock):
838855
"""Test that the method handles the sequence index properly.
@@ -1245,6 +1262,9 @@ def test_sample_with_all_null_column_categorical(self):
12451262
assert result['all_null_cat_col'].isna().all()
12461263
assert len(result) > 0
12471264

1265+
@pytest.mark.filterwarnings(
1266+
'error:Series.__getitem__ treating keys as positions is deprecated:FutureWarning'
1267+
)
12481268
def test_sample_with_multiple_all_null_columns(self):
12491269
"""Test that sampling works correctly with multiple all-null columns."""
12501270
# Setup
@@ -1257,15 +1277,21 @@ def test_sample_with_multiple_all_null_columns(self):
12571277
'all_null_col2': [np.nan] * 9,
12581278
})
12591279

1260-
metadata = Metadata()
1261-
metadata.add_table('table')
1262-
metadata.add_column('time', 'table', sdtype='datetime')
1263-
metadata.add_column('gender', 'table', sdtype='categorical')
1264-
metadata.add_column('name', 'table', sdtype='id')
1265-
metadata.add_column('measurement', 'table', sdtype='numerical')
1266-
metadata.add_column('all_null_col1', 'table', sdtype='numerical')
1267-
metadata.add_column('all_null_col2', 'table', sdtype='categorical')
1268-
metadata.set_sequence_key('name', 'table')
1280+
metadata = Metadata().load_from_dict({
1281+
'tables': {
1282+
'table': {
1283+
'columns': {
1284+
'time': {'sdtype': 'datetime'},
1285+
'gender': {'sdtype': 'categorical'},
1286+
'name': {'sdtype': 'id'},
1287+
'measurement': {'sdtype': 'numerical'},
1288+
'all_null_col1': {'sdtype': 'numerical'},
1289+
'all_null_col2': {'sdtype': 'categorical'},
1290+
},
1291+
'sequence_key': 'name',
1292+
}
1293+
}
1294+
})
12691295

12701296
# Run
12711297
synthesizer = PARSynthesizer(metadata=metadata, epochs=1)

tests/unit/single_table/test__dayz_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,26 @@ def test_detect_column_parameter():
2727
"""Test the `detect_column_parameters` method."""
2828
# Setup
2929
data = pd.DataFrame({
30+
'pk': [0, 1, 2, 3],
3031
'num_col': [1.0, 2.5, 3.0, None],
3132
'cat_col': ['A', 'B', 'A', None],
3233
'date_col': ['2020-01-01', '2020-01-02', None, None],
3334
'date_col_2': ['2020 Jan 01', '2020 Jan 02', '2020 Jan 03', None],
35+
'alt_key': ['id0', 'id1', 'id2', 'id3'],
3436
})
3537
metadata = Metadata.load_from_dict({
3638
'tables': {
3739
'table_name': {
3840
'columns': {
41+
'pk': {'sdtype': 'id'},
3942
'num_col': {'sdtype': 'numerical'},
4043
'cat_col': {'sdtype': 'categorical'},
4144
'date_col': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
4245
'date_col_2': {'sdtype': 'datetime'},
43-
}
46+
'alt_key': {'sdtype': 'ssn'},
47+
},
48+
'primary_key': 'pk',
49+
'alternate_keys': ['alt_key'],
4450
}
4551
}
4652
})
@@ -50,6 +56,7 @@ def test_detect_column_parameter():
5056
# Assert
5157
assert result == {
5258
'columns': {
59+
'pk': {'missing_values_proportion': 0.0},
5360
'num_col': {
5461
'num_decimal_digits': 1,
5562
'min_value': 1.0,
@@ -70,6 +77,7 @@ def test_detect_column_parameter():
7077
'end_timestamp': '2020-01-03 00:00:00',
7178
'missing_values_proportion': 0.25,
7279
},
80+
'alt_key': {'missing_values_proportion': 0.0},
7381
}
7482
}
7583

0 commit comments

Comments
 (0)