Skip to content

Commit f88e933

Browse files
committed
add unit and integration tests
1 parent 6388936 commit f88e933

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

tests/integration/sequential/test_par.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,3 +1002,49 @@ def test_add_constraints_with_context_columns():
10021002
synthesizer.fit(data)
10031003
samples = synthesizer.sample(5)
10041004
synthesizer.validate(samples)
1005+
1006+
1007+
def test_par_context_columns_invariance():
1008+
"""Test par is invariate to the order of context columns."""
1009+
# Setup
1010+
data = pd.DataFrame(
1011+
data={
1012+
'sequence': ['id-0'] * 3 + ['id-1'] * 4 + ['id-2'] * 3,
1013+
'context1': ['M'] * 3 + ['F'] * 4 + ['M'] * 3,
1014+
'context2': [12.0] * 3 + [np.nan] * 4 + [34.0] * 3,
1015+
'seq1': [12, 34, 12, 78, 12, 56, 34, 78, 12, 67],
1016+
'seq2': ['Yes', 'Yes', 'No', 'No', 'No', 'No', 'Yes', 'Yes', 'No', 'No'],
1017+
}
1018+
)
1019+
1020+
metadata = Metadata.load_from_dict({
1021+
'tables': {
1022+
'table': {
1023+
'columns': {
1024+
'sequence': {'sdtype': 'id'},
1025+
'context1': {'sdtype': 'categorical'},
1026+
'context2': {'sdtype': 'numerical'},
1027+
'seq1': {'sdtype': 'numerical'},
1028+
'seq2': {'sdtype': 'categorical'},
1029+
},
1030+
'sequence_key': 'sequence',
1031+
}
1032+
}
1033+
})
1034+
1035+
synthesizer1 = PARSynthesizer(metadata, epochs=1, context_columns=['context1', 'context2'])
1036+
1037+
synthesizer2 = PARSynthesizer(metadata, epochs=1, context_columns=['context2', 'context1'])
1038+
1039+
# Run
1040+
synthesizer1.fit(data)
1041+
samples1 = synthesizer1.sample(num_sequences=3, sequence_length=2)
1042+
1043+
synthesizer2.fit(data)
1044+
samples2 = synthesizer2.sample(num_sequences=3, sequence_length=2)
1045+
1046+
# Assert
1047+
assert samples1.shape == samples2.shape
1048+
assert samples1.columns.equals(samples2.columns)
1049+
synthesizer1.validate(samples2)
1050+
synthesizer2.validate(samples1)

tests/unit/sequential/test_par.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,38 @@ def test_update_transformers_context_column(self):
435435
with pytest.raises(SynthesizerInputError, match=err_msg):
436436
instance.update_transformers({'time': FloatFormatter()})
437437

438+
def test__fit_reorder_context_columns_incorrect_order(self):
439+
"""Test that the context columns are reordered according to data."""
440+
# Setup
441+
metadata = self.get_metadata()
442+
data = self.get_data()
443+
444+
data.insert(1, 'height', [160, 170, 180])
445+
metadata.add_column('height', 'table', sdtype='numerical')
446+
instance = PARSynthesizer(metadata, context_columns=['gender', 'height'])
447+
448+
# Run
449+
instance.fit(data)
450+
451+
# Assert
452+
assert instance.context_columns == ['height', 'gender']
453+
454+
def test__fit_reorder_context_columns_correct_order(self):
455+
"""Test that the context columns is still the same order."""
456+
# Setup
457+
metadata = self.get_metadata()
458+
data = self.get_data()
459+
460+
data.insert(2, 'height', [160, 170, 180])
461+
metadata.add_column('height', 'table', sdtype='numerical')
462+
instance = PARSynthesizer(metadata, context_columns=['gender', 'height'])
463+
464+
# Run
465+
instance.fit(data)
466+
467+
# Assert
468+
assert instance.context_columns == ['gender', 'height']
469+
438470
@patch('sdv.sequential.par.GaussianCopulaSynthesizer')
439471
def test__fit_context_model_with_context_columns(self, gaussian_copula_mock):
440472
"""Test that the method fits a synthesizer to the context columns.

0 commit comments

Comments
 (0)