Skip to content

Commit 0047a8d

Browse files
authored
Allow users to estimate parameters for DayZSynthesizer (#2670)
1 parent 6482447 commit 0047a8d

File tree

12 files changed

+707
-1
lines changed

12 files changed

+707
-1
lines changed

sdv/multi_table/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Synthesizers for Multi Table data."""
22

33
from sdv.multi_table.hma import HMASynthesizer
4+
from sdv.multi_table.dayz import DayZSynthesizer
45

5-
__all__ = ('HMASynthesizer',)
6+
__all__ = ('HMASynthesizer', 'DayZSynthesizer')

sdv/multi_table/_dayz_utils.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import json
2+
3+
import pandas as pd
4+
5+
from sdv.single_table.dayz import create_parameters
6+
7+
8+
def detect_relationship_parameters(data, metadata):
9+
"""Detect all relationship-level for the DayZ parameters.
10+
11+
The relationship-level parameters are:
12+
- The min and max cardinality
13+
14+
Args:
15+
data (dict[str, pd.DataFrame]): The input data.
16+
metadata (Metadata): The metadata object.
17+
18+
Returns:
19+
dict: A dictionary containing the detected parameters.
20+
"""
21+
relationship_parameters = {}
22+
for relationship in metadata.relationships:
23+
rel_tuple = (
24+
relationship['parent_table_name'],
25+
relationship['child_table_name'],
26+
relationship['parent_primary_key'],
27+
relationship['child_foreign_key'],
28+
)
29+
cardinality_table = pd.DataFrame(index=data[rel_tuple[0]][rel_tuple[2]].copy())
30+
cardinality_table['cardinality'] = data[rel_tuple[1]][rel_tuple[3]].value_counts()
31+
cardinality_table = cardinality_table.fillna(0)
32+
relationship_parameters[json.dumps(rel_tuple)] = {
33+
'min_cardinality': cardinality_table['cardinality'].min(),
34+
'max_cardinality': cardinality_table['cardinality'].max(),
35+
}
36+
37+
return relationship_parameters
38+
39+
40+
def create_parameters_multi_table(data, metadata, output_filename):
41+
"""Create parameters for the DayZSynthesizer."""
42+
parameters = create_parameters(data, metadata, None)
43+
parameters['relationships'] = detect_relationship_parameters(data, metadata)
44+
if output_filename:
45+
with open(output_filename, 'w') as f:
46+
json.dump(parameters, f, indent=4)
47+
48+
return parameters

sdv/multi_table/dayz.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""Multi-Table DayZ parameter detection and creation."""
2+
3+
from sdv.errors import SynthesizerInputError
4+
from sdv.multi_table._dayz_utils import create_parameters_multi_table
5+
6+
7+
class DayZSynthesizer:
8+
"""Multi-Table DayZSynthesizer for public SDV."""
9+
10+
def __init__(self, metadata, locales=['en_US']):
11+
raise SynthesizerInputError(
12+
"Only the 'DayZSynthesizer.create_parameters' is a SDV public feature. "
13+
'To define and use and use a DayZSynthesizer object you must have SDV-Enterprise.'
14+
)
15+
16+
@classmethod
17+
def create_parameters(cls, data, metadata, output_filename=None):
18+
"""Create parameters for the DayZSynthesizer.
19+
20+
Args:
21+
data (dict[str, pd.DataFrame]): The input data.
22+
metadata (Metadata): The metadata object.
23+
output_filename (str, optional): The output filename for the parameters.
24+
25+
Returns:
26+
dict: The created parameters.
27+
"""
28+
return create_parameters_multi_table(data, metadata, output_filename)

sdv/single_table/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from sdv.single_table.copulagan import CopulaGANSynthesizer
44
from sdv.single_table.copulas import GaussianCopulaSynthesizer
55
from sdv.single_table.ctgan import CTGANSynthesizer, TVAESynthesizer
6+
from sdv.single_table.dayz import DayZSynthesizer
67

78
__all__ = (
89
'GaussianCopulaSynthesizer',
910
'CTGANSynthesizer',
1011
'TVAESynthesizer',
1112
'CopulaGANSynthesizer',
13+
'DayZSynthesizer',
1214
)

sdv/single_table/_dayz_utils.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import json
2+
3+
import pandas as pd
4+
from rdt.transformers.utils import learn_rounding_digits
5+
6+
7+
def detect_table_parameters(data):
8+
"""Detect all table-level Dayz parameters.
9+
10+
- Detect the `num_rows` of the table.
11+
12+
Args:
13+
data (pd.DataFrame): The input data.
14+
15+
Returns:
16+
dict: A dictionary containing the detected parameters.
17+
"""
18+
return {'num_rows': len(data)}
19+
20+
21+
def detect_column_parameters(data, metadata, table_name):
22+
"""Detect all column-level Dayz parameters.
23+
24+
The column-level parameters are:
25+
- The missing value proportion
26+
- The boundaries for numerical and datetime columns
27+
- The categories for categorical columns
28+
- The 'num_decimal_digits' for numerical columns
29+
30+
Args:
31+
data (pd.DataFrame): The input data.
32+
metadata (Metadata): The metadata object.
33+
34+
Returns:
35+
dict: A dictionary containing the detected parameters.
36+
"""
37+
table_metadata = metadata.tables[table_name]
38+
column_parameters = {}
39+
for column_name, column_metadata in table_metadata.columns.items():
40+
column_parameters[column_name] = {}
41+
sdtype = column_metadata['sdtype']
42+
if sdtype == 'numerical':
43+
column_parameters[column_name] = {
44+
'num_decimal_digits': learn_rounding_digits(data[column_name]),
45+
'min_value': data[column_name].min().item(),
46+
'max_value': data[column_name].max().item(),
47+
}
48+
elif sdtype == 'datetime':
49+
datetime_format = column_metadata.get('datetime_format', None)
50+
if datetime_format:
51+
datetime_column = pd.to_datetime(
52+
data[column_name], format=datetime_format, errors='coerce'
53+
)
54+
start_timestamp = datetime_column.min().strftime(datetime_format)
55+
end_timestamp = datetime_column.max().strftime(datetime_format)
56+
57+
else:
58+
datetime_column = pd.to_datetime(data[column_name], errors='coerce')
59+
start_timestamp = str(datetime_column.min())
60+
end_timestamp = str(datetime_column.max())
61+
62+
column_parameters[column_name] = {
63+
'start_timestamp': start_timestamp,
64+
'end_timestamp': end_timestamp,
65+
}
66+
elif sdtype in ['categorical', 'boolean']:
67+
column_parameters[column_name] = {
68+
'category_values': data[column_name].dropna().unique().tolist()
69+
}
70+
71+
column_parameters[column_name]['missing_values_proportion'] = (
72+
data[column_name].isna().mean().item()
73+
)
74+
75+
return {'columns': column_parameters}
76+
77+
78+
def create_parameters(data, metadata, output_filename):
79+
"""Detect and create a parameter dict for the DayZ model."""
80+
metadata.validate()
81+
datas = data if isinstance(data, dict) else {metadata._get_single_table_name(): data}
82+
metadata.validate_data(datas)
83+
parameters = {'DAYZ_SPEC_VERSION': 'V1', 'tables': {}}
84+
for table_name, table_data in datas.items():
85+
parameters['tables'][table_name] = {}
86+
parameters['tables'][table_name].update(detect_table_parameters(table_data))
87+
parameters['tables'][table_name].update(
88+
detect_column_parameters(table_data, metadata, table_name)
89+
)
90+
91+
if output_filename:
92+
with open(output_filename, 'w') as f:
93+
json.dump(parameters, f, indent=4)
94+
95+
return parameters

sdv/single_table/dayz.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""DayZ parameter detection and creation."""
2+
3+
from sdv.errors import SynthesizerInputError
4+
from sdv.single_table._dayz_utils import create_parameters
5+
6+
7+
class DayZSynthesizer:
8+
"""Single-Table DayZSynthesizer for public SDV."""
9+
10+
def __init__(self, metadata, locales=['en_US']):
11+
raise SynthesizerInputError(
12+
"Only the 'DayZSynthesizer.create_parameters' is a SDV public feature. "
13+
'To define and use and use a DayZSynthesizer object you must have SDV-Enterprise.'
14+
)
15+
16+
@classmethod
17+
def create_parameters(cls, data, metadata, output_filename=None):
18+
"""Create parameters for the DayZ synthesizer.
19+
20+
Args:
21+
data (pd.DataFrame): The input data.
22+
metadata (Metadata): The metadata object.
23+
output_filename (str, optional): The output filename for the parameters.
24+
25+
Returns:
26+
dict: The created parameters.
27+
"""
28+
return create_parameters(data, metadata, output_filename)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""Integration tests for DayZ parameter detection."""
2+
3+
from sdv.datasets.demo import download_demo
4+
from sdv.multi_table import DayZSynthesizer
5+
6+
7+
class TestDayZSynthesizer:
8+
def test_create_parameters_end_to_end(self):
9+
"""Test the `create_parameters` method end to end."""
10+
# Setup
11+
data, metadata = download_demo(modality='multi_table', dataset_name='fake_hotels')
12+
13+
# Run
14+
parameters = DayZSynthesizer.create_parameters(data, metadata)
15+
16+
# Assert
17+
expected_results = {
18+
'DAYZ_SPEC_VERSION': 'V1',
19+
'tables': {
20+
'guests': {
21+
'num_rows': 658,
22+
'columns': {
23+
'guest_email': {'missing_values_proportion': 0.0},
24+
'hotel_id': {'missing_values_proportion': 0.0},
25+
'has_rewards': {
26+
'category_values': [False, True],
27+
'missing_values_proportion': 0.0,
28+
},
29+
'room_type': {
30+
'category_values': ['BASIC', 'DELUXE', 'SUITE'],
31+
'missing_values_proportion': 0.0,
32+
},
33+
'amenities_fee': {
34+
'num_decimal_digits': 2,
35+
'min_value': 0.0,
36+
'max_value': 46.64,
37+
'missing_values_proportion': 0.07598784194528875,
38+
},
39+
'checkin_date': {
40+
'start_timestamp': '03 Jan 2020',
41+
'end_timestamp': '05 Jan 2021',
42+
'missing_values_proportion': 0.0,
43+
},
44+
'checkout_date': {
45+
'start_timestamp': '04 Jan 2020',
46+
'end_timestamp': '07 Jan 2021',
47+
'missing_values_proportion': 0.04559270516717325,
48+
},
49+
'room_rate': {
50+
'num_decimal_digits': 2,
51+
'min_value': 48.33,
52+
'max_value': 481.61,
53+
'missing_values_proportion': 0.0,
54+
},
55+
'billing_address': {'missing_values_proportion': 0.0},
56+
'credit_card_number': {'missing_values_proportion': 0.0},
57+
},
58+
},
59+
'hotels': {
60+
'num_rows': 10,
61+
'columns': {
62+
'hotel_id': {'missing_values_proportion': 0.0},
63+
'city': {
64+
'category_values': [
65+
'Boston',
66+
'San Francisco',
67+
'New York City',
68+
'Austin',
69+
'Los Angeles',
70+
],
71+
'missing_values_proportion': 0.0,
72+
},
73+
'state': {
74+
'category_values': [
75+
'Massachusetts',
76+
'Massachuesetts',
77+
'California',
78+
'New York',
79+
'Texas',
80+
],
81+
'missing_values_proportion': 0.0,
82+
},
83+
'rating': {
84+
'num_decimal_digits': 1,
85+
'min_value': 3.7,
86+
'max_value': 4.9,
87+
'missing_values_proportion': 0.1,
88+
},
89+
'classification': {
90+
'category_values': ['RESORT', 'CHAIN', 'MOTEL'],
91+
'missing_values_proportion': 0.0,
92+
},
93+
},
94+
},
95+
},
96+
'relationships': {
97+
'["hotels", "guests", "hotel_id", "hotel_id"]': {
98+
'min_cardinality': 15,
99+
'max_cardinality': 137,
100+
}
101+
},
102+
}
103+
assert parameters == expected_results
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""Integration tests for DayZ parameter detection."""
2+
3+
from sdv.datasets.demo import download_demo
4+
from sdv.single_table import DayZSynthesizer
5+
6+
7+
class TestDayZSynthesizer:
8+
def test_create_parameters_end_to_end(self):
9+
"""Test the `create_parameters` method end to end."""
10+
# Setup
11+
data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests')
12+
13+
# Run
14+
parameters = DayZSynthesizer.create_parameters(data, metadata)
15+
16+
# Assert
17+
expected_results = {
18+
'DAYZ_SPEC_VERSION': 'V1',
19+
'tables': {
20+
'fake_hotel_guests': {
21+
'num_rows': 500,
22+
'columns': {
23+
'guest_email': {'missing_values_proportion': 0.0},
24+
'has_rewards': {
25+
'category_values': [False, True],
26+
'missing_values_proportion': 0.0,
27+
},
28+
'room_type': {
29+
'category_values': ['BASIC', 'DELUXE', 'SUITE'],
30+
'missing_values_proportion': 0.0,
31+
},
32+
'amenities_fee': {
33+
'num_decimal_digits': 2,
34+
'min_value': 0.0,
35+
'max_value': 48.12,
36+
'missing_values_proportion': 0.09,
37+
},
38+
'checkin_date': {
39+
'start_timestamp': '05 Jan 2020',
40+
'end_timestamp': '07 Jan 2021',
41+
'missing_values_proportion': 0.0,
42+
},
43+
'checkout_date': {
44+
'start_timestamp': '07 Jan 2020',
45+
'end_timestamp': '08 Jan 2021',
46+
'missing_values_proportion': 0.04,
47+
},
48+
'room_rate': {
49+
'num_decimal_digits': 2,
50+
'min_value': 83.8,
51+
'max_value': 424.84,
52+
'missing_values_proportion': 0.0,
53+
},
54+
'billing_address': {'missing_values_proportion': 0.0},
55+
'credit_card_number': {'missing_values_proportion': 0.0},
56+
},
57+
},
58+
},
59+
}
60+
assert parameters == expected_results

0 commit comments

Comments
 (0)