Skip to content

Commit a7eb1f8

Browse files
authored
Reordering context columns in PARSynthesizer (#2726)
1 parent c9f495a commit a7eb1f8

File tree

3 files changed

+84
-0
lines changed

3 files changed

+84
-0
lines changed

sdv/sequential/par.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,10 @@ def _get_id_context_columns(self):
378378
if self._get_table_metadata().columns[col]['sdtype'] not in MODELABLE_SDTYPES
379379
]
380380

381+
def _reorder_context_columns(self, context_columns, timeseries_data):
382+
order = {column: i for i, column in enumerate(timeseries_data.columns)}
383+
return sorted(context_columns, key=lambda x: order.get(x, float('inf')))
384+
381385
def _preprocess(self, data):
382386
"""Transform the raw data to numerical space.
383387
@@ -539,6 +543,8 @@ def _fit(self, processed_data):
539543
pandas.DataFrame containing both the sequences,
540544
the entity columns and the context columns.
541545
"""
546+
self.context_columns = self._reorder_context_columns(self.context_columns, processed_data)
547+
542548
if self._sequence_key:
543549
self._fit_context_model(processed_data)
544550

tests/integration/sequential/test_par.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,3 +1002,49 @@ def test_add_constraints_with_context_columns():
10021002
synthesizer.fit(data)
10031003
samples = synthesizer.sample(5)
10041004
synthesizer.validate(samples)
1005+
1006+
1007+
def test_par_context_columns_invariance():
    """Test that PAR is invariant to the order of context columns."""
    # Setup
    data = pd.DataFrame(
        data={
            'sequence': ['id-0'] * 3 + ['id-1'] * 4 + ['id-2'] * 3,
            'context1': ['M'] * 3 + ['F'] * 4 + ['M'] * 3,
            'context2': [12.0] * 3 + [np.nan] * 4 + [34.0] * 3,
            'seq1': [12, 34, 12, 78, 12, 56, 34, 78, 12, 67],
            'seq2': ['Yes', 'Yes', 'No', 'No', 'No', 'No', 'Yes', 'Yes', 'No', 'No'],
        }
    )

    metadata = Metadata.load_from_dict({
        'tables': {
            'table': {
                'columns': {
                    'sequence': {'sdtype': 'id'},
                    'context1': {'sdtype': 'categorical'},
                    'context2': {'sdtype': 'numerical'},
                    'seq1': {'sdtype': 'numerical'},
                    'seq2': {'sdtype': 'categorical'},
                },
                'sequence_key': 'sequence',
            }
        }
    })

    # Same context columns, declared in opposite orders.
    synthesizer1 = PARSynthesizer(metadata, epochs=1, context_columns=['context1', 'context2'])

    synthesizer2 = PARSynthesizer(metadata, epochs=1, context_columns=['context2', 'context1'])

    # Run
    synthesizer1.fit(data)
    samples1 = synthesizer1.sample(num_sequences=3, sequence_length=2)

    synthesizer2.fit(data)
    samples2 = synthesizer2.sample(num_sequences=3, sequence_length=2)

    # Assert
    assert samples1.shape == samples2.shape
    assert samples1.columns.equals(samples2.columns)
    # Each synthesizer must accept the samples produced by the other one.
    synthesizer1.validate(samples2)
    synthesizer2.validate(samples1)

tests/unit/sequential/test_par.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,38 @@ def test_update_transformers_context_column(self):
451451
with pytest.raises(SynthesizerInputError, match=err_msg):
452452
instance.update_transformers({'time': FloatFormatter()})
453453

454+
def test__fit_reorder_context_columns_incorrect_order(self):
    """Test that ``fit`` reorders the context columns to match the data order."""
    # Setup
    metadata = self.get_metadata()
    data = self.get_data()

    # Insert 'height' at position 1 so the declared order below differs from the data.
    data.insert(1, 'height', [160, 170, 180])
    metadata.add_column('height', 'table', sdtype='numerical')
    instance = PARSynthesizer(metadata, context_columns=['gender', 'height'])

    # Run
    instance.fit(data)

    # Assert
    assert instance.context_columns == ['height', 'gender']
469+
470+
def test__fit_reorder_context_columns_correct_order(self):
    """Test that the context columns keep their order when it already matches the data."""
    # Setup
    metadata = self.get_metadata()
    data = self.get_data()

    # Insert 'height' at position 2 so the declared order below is expected to stay as is.
    data.insert(2, 'height', [160, 170, 180])
    metadata.add_column('height', 'table', sdtype='numerical')
    instance = PARSynthesizer(metadata, context_columns=['gender', 'height'])

    # Run
    instance.fit(data)

    # Assert
    assert instance.context_columns == ['gender', 'height']
485+
454486
@patch('sdv.sequential.par.GaussianCopulaSynthesizer')
455487
def test__fit_context_model_with_context_columns(self, gaussian_copula_mock):
456488
"""Test that the method fits a synthesizer to the context columns.

0 commit comments

Comments
 (0)