Skip to content

Commit fd1b30a

Browse files
authored
Add summarize function to SDGymResultsExplorer class (#424)
1 parent b1a5bd9 commit fd1b30a

File tree

13 files changed

+590
-1
lines changed

13 files changed

+590
-1
lines changed

sdgym/sdgym_result_explorer/result_explorer.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,17 @@ def load_real_data(self, dataset_name):
8181
aws_secret=self.aws_secret_access_key,
8282
)
8383
return data
84+
85+
def summarize(self, folder_name):
86+
"""Summarize the results in the specified folder.
87+
88+
Args:
89+
folder_name (str):
90+
The name of the results folder to summarize.
91+
92+
Returns:
93+
tuple (pd.DataFrame, pd.DataFrame):
94+
- A summary DataFrame with the number of Wins per synthesizer.
95+
- A DataFrame with the results of the benchmark for the specified folder.
96+
"""
97+
return self._handler.summarize(folder_name)

sdgym/sdgym_result_explorer/result_handler.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
"""Handlers for managing SDGym benchmark results, supporting both local and S3 storage."""
22

33
import io
4+
import operator
45
import os
56
import pickle
67
from abc import ABC, abstractmethod
8+
from datetime import datetime
79

810
import pandas as pd
11+
import yaml
912
from botocore.exceptions import ClientError
1013

14+
SYNTHESIZER_BASELINE = 'GaussianCopulaSynthesizer'
15+
RESULTS_FOLDER_PREFIX = 'SDGym_results_'
16+
RUN_ID_PREFIX = 'run_'
17+
RESULTS_FILE_PREFIX = 'results_'
18+
1119

1220
class ResultsHandler(ABC):
1321
"""Abstract base class for handling results storage and retrieval."""
@@ -32,6 +40,121 @@ def load_synthetic_data(self, file_path):
3240
"""Load synthetic data from a file."""
3341
pass
3442

43+
@abstractmethod
44+
def _load_yaml_file(self, folder_name, file_name):
45+
"""Load a YAML file from the results folder."""
46+
pass
47+
48+
def _compute_wins(self, result):
49+
synthesizers = result['Synthesizer'].unique()
50+
datasets = result['Dataset'].unique()
51+
result['Win'] = 0
52+
for dataset in datasets:
53+
score_baseline = result.loc[
54+
(result['Synthesizer'] == SYNTHESIZER_BASELINE) & (result['Dataset'] == dataset)
55+
]['Quality_Score'].to_numpy()
56+
if score_baseline.size == 0:
57+
continue
58+
59+
for synthesizer in synthesizers:
60+
loc_synthesizer = (result['Synthesizer'] == synthesizer) & (
61+
result['Dataset'] == dataset
62+
)
63+
score_synthesizer = result.loc[loc_synthesizer]['Quality_Score'].to_numpy()
64+
result.loc[loc_synthesizer, 'Win'] = (score_synthesizer > score_baseline).astype(
65+
int
66+
)
67+
68+
def _get_summarize_table(self, folder_to_results, folder_infos):
69+
"""Create a summary table from the results."""
70+
columns = []
71+
for folder, results in folder_to_results.items():
72+
date_str = folder_infos[folder]['date']
73+
date_obj = datetime.strptime(date_str, '%m_%d_%Y')
74+
column_name = (
75+
f'{date_str}'
76+
f' - # datasets: {folder_infos[folder]["# datasets"]}'
77+
f' - sdgym version: {folder_infos[folder]["sdgym_version"]}'
78+
)
79+
results = results.loc[results['Synthesizer'] != SYNTHESIZER_BASELINE]
80+
column_data = results.groupby(['Synthesizer'])['Win'].sum()
81+
columns.append((date_obj, column_name, column_data))
82+
83+
columns.sort(key=operator.itemgetter(0))
84+
summarized_results = pd.DataFrame()
85+
for _, column_name, column_data in reversed(columns):
86+
summarized_results[column_name] = column_data
87+
88+
summarized_results = summarized_results.fillna('-')
89+
return summarized_results
90+
91+
def _get_column_name_infos(self, folder_to_results):
92+
folder_to_info = {}
93+
for folder, results in folder_to_results.items():
94+
yaml_files = self._get_results_files(folder, prefix=RUN_ID_PREFIX, suffix='.yaml')
95+
if not yaml_files:
96+
continue
97+
98+
run_id_info = self._load_yaml_file(folder, yaml_files[0])
99+
num_datasets = results.loc[
100+
results['Synthesizer'] == SYNTHESIZER_BASELINE, 'Dataset'
101+
].nunique()
102+
folder_to_info[folder] = {
103+
'date': run_id_info.get('starting_date')[:10], # Extract only the YYYY-MM-DD
104+
'sdgym_version': run_id_info.get('sdgym_version'),
105+
'# datasets': num_datasets,
106+
}
107+
108+
return folder_to_info
109+
110+
def _process_results(self, results):
111+
"""Process results to ensure they are unique and each dataset has all synthesizers."""
112+
aggregated_results = pd.concat(results, ignore_index=True)
113+
aggregated_results = aggregated_results.drop_duplicates(subset=['Dataset', 'Synthesizer'])
114+
all_synthesizers = aggregated_results['Synthesizer'].unique()
115+
dataset_synth_counts = aggregated_results.groupby('Dataset')['Synthesizer'].nunique()
116+
valid_datasets = dataset_synth_counts[dataset_synth_counts == len(all_synthesizers)].index
117+
filtered_results = aggregated_results[aggregated_results['Dataset'].isin(valid_datasets)]
118+
if filtered_results.empty:
119+
raise ValueError(
120+
'There is no dataset that has been run by all synthesizers. Cannot '
121+
'summarize results.'
122+
)
123+
124+
return filtered_results.reset_index(drop=True)
125+
126+
def summarize(self, folder_name):
127+
"""Summarize the results in the specified folder."""
128+
all_folders = [f for f in self.list() if f.startswith(RESULTS_FOLDER_PREFIX)]
129+
if folder_name not in all_folders:
130+
raise ValueError(f'Folder "{folder_name}" does not exist in the results directory.')
131+
132+
date = pd.to_datetime(folder_name[-10:], format='%m_%d_%Y')
133+
folder_to_results = {}
134+
for folder in all_folders:
135+
folder_date = pd.to_datetime(folder[len(RESULTS_FOLDER_PREFIX) :], format='%m_%d_%Y')
136+
if folder_date > date:
137+
continue
138+
139+
result_filenames = self._get_results_files(
140+
folder, prefix=RESULTS_FILE_PREFIX, suffix='.csv'
141+
)
142+
if not result_filenames:
143+
continue
144+
145+
results = self._get_results(folder, result_filenames)
146+
if not results:
147+
continue
148+
149+
aggregated_results = self._process_results(results)
150+
self._compute_wins(aggregated_results)
151+
folder_to_results[folder] = aggregated_results
152+
folder_infos = self._get_column_name_infos(folder_to_results)
153+
154+
summarized_table = self._get_summarize_table(folder_to_results, folder_infos)
155+
156+
return summarized_table, folder_to_results[folder_name]
157+
35158

36159
class LocalResultsHandler(ResultsHandler):
37160
"""Results handler for local filesystem."""
@@ -65,6 +188,24 @@ def load_synthetic_data(self, file_path):
65188
"""Load synthetic data from a CSV file."""
66189
return pd.read_csv(os.path.join(self.base_path, file_path))
67190

191+
def _get_results_files(self, folder_name, prefix, suffix):
192+
return [
193+
f
194+
for f in os.listdir(os.path.join(self.base_path, folder_name))
195+
if f.endswith(suffix) and f.startswith(prefix)
196+
]
197+
198+
def _get_results(self, folder_name, file_names):
199+
return [
200+
pd.read_csv(os.path.join(self.base_path, folder_name, file_name))
201+
for file_name in file_names
202+
]
203+
204+
def _load_yaml_file(self, folder_name, file_name):
205+
file_path = os.path.join(self.base_path, folder_name, file_name)
206+
with open(file_path, 'r') as f:
207+
return yaml.safe_load(f)
208+
68209

69210
class S3ResultsHandler(ResultsHandler):
70211
"""Results handler for AWS S3 storage."""
@@ -103,3 +244,30 @@ def load_synthetic_data(self, file_path):
103244
Bucket=self.bucket_name, Key=f'{self.prefix}{file_path}'
104245
)
105246
return pd.read_csv(io.BytesIO(response['Body'].read()))
247+
248+
def _get_results_files(self, folder_name, prefix, suffix):
249+
s3_prefix = f'{self.prefix}{folder_name}/'
250+
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=s3_prefix)
251+
if 'Contents' not in response:
252+
return []
253+
254+
return [
255+
obj['Key'].split('/')[-1]
256+
for obj in response['Contents']
257+
if obj['Key'].startswith(s3_prefix + prefix) and obj['Key'].endswith(suffix)
258+
]
259+
260+
def _get_results(self, folder_name, file_names):
261+
results = []
262+
for file_name in file_names:
263+
s3_key = f'{self.prefix}{folder_name}/{file_name}'
264+
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=s3_key)
265+
df = pd.read_csv(io.BytesIO(response['Body'].read()))
266+
results.append(df)
267+
268+
return results
269+
270+
def _load_yaml_file(self, folder_name, file_name):
271+
s3_key = f'{self.prefix}{folder_name}/{file_name}'
272+
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=s3_key)
273+
return yaml.safe_load(response['Body'])
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
Synthesizer,Dataset,Dataset_Size_MB,Train_Time,Peak_Memory_MB,Synthesizer_Size_MB,Sample_Time,Evaluate_Time,Quality_Score,NewRowSynthesis,Win
2+
CTGANSynthesizer,adult,3.907448,2342.444013,128.172054,3.827279,1.759033,25.273611,0.8714924623776077,0.999,True
3+
CTGANSynthesizer,alarm,4.520128,1487.424623,55.766691,7.383351,2.123863,60.076879000000005,0.9335235735735736,0.765,False
4+
CTGANSynthesizer,census,98.165608,122285.653783,3570.967303,73.361215,41.595805,467.547645,0.8973692159020539,0.962,True
5+
CTGANSynthesizer,child,3.200128,1003.500687,32.6008,4.282823,1.181363,28.008788000000003,0.9131484210526316,0.719,False
6+
CTGANSynthesizer,expedia_hotel_logs,0.200128,152.122757,15.751031,5.169425,0.736054,24.029036,0.6936163099206261,1.0,False
7+
CTGANSynthesizer,insurance,3.340128,1246.284132,46.761978,5.615351,1.554094,38.10747,0.9059027777777776,0.913,False
8+
CTGANSynthesizer,intrusion,162.039016,82807.286322,2657.160885,17.832114,74.913712,266.39056400000004,0.7791883833282515,0.975,True
9+
CTGANSynthesizer,news,18.712096,3693.241304,272.001799,2.231116,7.262025,105.143338,0.959474669126202,1.0,True
10+
CTGANSynthesizer,covtype,255.645408,100610.359588,3483.111359,6.744468,89.162266,321.334703,0.9406897050649742,1.0,True
11+
CopulaGANSynthesizer,adult,3.907448,2314.982089,129.07933,3.85024,2.090281,24.57947,0.8948133644497218,0.996,True
12+
CopulaGANSynthesizer,alarm,4.520128,1471.227566,60.312497,7.384683,2.092101,59.962778,0.912150900900901,0.777,False
13+
CopulaGANSynthesizer,census,98.165608,121247.643918,3611.802882,73.397392,48.615023,463.407474,0.8844791502645143,0.919,False
14+
CopulaGANSynthesizer,child,3.200128,983.393637,35.825856,4.283679,1.146471,27.611243,0.9245040789473684,0.748,False
15+
CopulaGANSynthesizer,expedia_hotel_logs,0.200128,153.118696,15.722396,5.153251,0.77662,24.254977,0.6274705731625682,1.0,False
16+
CopulaGANSynthesizer,insurance,3.340128,1258.885573,50.114157,5.616403,1.552372,38.96672,0.8901035612535613,0.929,False
17+
CopulaGANSynthesizer,intrusion,162.039016,,,,,,,,False
18+
CopulaGANSynthesizer,news,18.712096,4022.354968,783.27276,2.653436,11.194702,107.250757,0.949910184852131,1.0,True
19+
CopulaGANSynthesizer,covtype,255.645408,100300.165465,3739.289605,7.057369,171.153944,307.30554099999995,0.7664487459476463,1.0,True
20+
FastMLPreset,adult,3.907448,7.370428,40.996445,0.169626,1.239985,25.325739,0.7941277750188948,1.0,False
21+
FastMLPreset,alarm,4.520128,23.086832,61.416743,0.298763,2.047489,60.192136000000005,0.9365205705705704,0.939,False
22+
FastMLPreset,census,98.165608,214.864156,1009.269257,0.358662,31.746223,474.334129,0.8158035999746163,1.0,False
23+
FastMLPreset,child,3.200128,9.463954,33.318735,0.19316,1.124241,27.838249,0.9142255263157896,0.969,False
24+
FastMLPreset,expedia_hotel_logs,0.200128,5.831513,2.845928,0.291425,0.617056,24.560681,0.7143156155762755,1.0,True
25+
FastMLPreset,insurance,3.340128,16.409016,44.900113,0.237351,1.420296,38.827266,0.9062499287749288,0.998,False
26+
FastMLPreset,intrusion,162.039016,99.881763,1665.472329,0.333307,41.30665,277.45795699999996,0.8395414299423576,1.0,True
27+
FastMLPreset,news,18.712096,8.493529,192.931325,0.450737,4.957063,104.682417,0.9333446855644932,1.0,True
28+
FastMLPreset,covtype,255.645408,119.831842,2625.796359,0.422665,68.566325,313.40950399999997,0.9534631256047564,1.0,True
29+
GaussianCopulaSynthesizer,adult,3.907448,18.016667,41.03898,0.16917,1.834246,25.891107,0.8460627797682108,1.0,False
30+
GaussianCopulaSynthesizer,alarm,4.520128,37.043426,61.430106,0.300693,2.790147,61.602110999999994,0.9824800675675674,0.965,False
31+
GaussianCopulaSynthesizer,census,98.165608,369.925453,1009.282085,0.360762,44.007031,476.547258,0.8894152679182734,1.0,False
32+
GaussianCopulaSynthesizer,child,3.200128,16.398928,33.328972,0.193512,1.504567,27.775566,0.9680398684210528,0.982,False
33+
GaussianCopulaSynthesizer,expedia_hotel_logs,0.200128,10.218074,2.854699,0.292641,0.63049,23.652367,0.6967257630420702,1.0,False
34+
GaussianCopulaSynthesizer,insurance,3.340128,25.745081,44.890834,0.238681,1.951241,38.162188,0.970242663817664,0.997,False
35+
GaussianCopulaSynthesizer,intrusion,162.039016,278.819512,1665.478969,0.335143,59.612199,272.88507,0.7256691671182501,1.0,False
36+
GaussianCopulaSynthesizer,news,18.712096,49.950842,192.958997,0.454225,7.793331,101.959868,0.8898934188617345,1.0,False
37+
GaussianCopulaSynthesizer,covtype,255.645408,373.237209,2625.811,0.425605,111.684247,298.469962,0.6326398468610831,1.0,False
38+
TVAESynthesizer,adult,3.907448,995.675674,128.20047,0.937158,1.338342,25.471986,0.8848574626770297,1.0,True
39+
TVAESynthesizer,alarm,4.520128,468.996626,55.745239,0.856925,1.683701,62.165522,0.9722823573573576,0.514,False
40+
TVAESynthesizer,census,98.165608,26245.855674,3570.953555,5.225792,30.817869,485.354262,0.9138488725217512,0.957,True
41+
TVAESynthesizer,child,3.200128,457.011626,32.616572,0.692508,0.965435,28.631832000000003,0.9612931578947368,0.743,False
42+
TVAESynthesizer,expedia_hotel_logs,0.200128,39.082162,15.760265,0.843527,0.735609,26.927474,0.6959875048061848,0.932,False
43+
TVAESynthesizer,insurance,3.340128,446.521891,46.766004,0.767688,1.303873,40.311753,0.9201681623931623,0.941,False
44+
TVAESynthesizer,intrusion,162.039016,38311.652012,2657.160525,8.049489,63.251776,278.383287,0.8711158830516244,0.962,True
45+
TVAESynthesizer,news,18.712096,2906.38651,271.971459,1.852811,5.997912,97.436519,0.9152536710657996,1.0,True
46+
TVAESynthesizer,covtype,255.645408,52681.625463,3483.13211,9.56187,68.155652,314.825924,0.940890148606724,1.0,True
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
completed_date: 04_05_2024 16:00:12
2+
jobs: None
3+
run_id: run_04_05_2024_1
4+
sdgym_version: 0.7.0
5+
sdv_version: 1.24.0
6+
starting_date: 04_05_2024 15:56:03
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
Synthesizer,Dataset,Dataset_Size_MB,Train_Time,Peak_Memory_MB,Synthesizer_Size_MB,Sample_Time,Evaluate_Time,Quality_Score,NewRowSynthesis,Win
2+
CTGANSynthesizer,adult,3.907448,2335.220195,128.165694,3.827328,1.751409,25.020344,0.8500066457227371,1.0,True
3+
CTGANSynthesizer,alarm,4.520128,1459.873185,55.756612,7.383406,2.068689,59.54888,0.9420208333333334,0.751,False
4+
CTGANSynthesizer,census,98.165608,122629.33064,3570.964094,73.361264,41.049132,458.13628,0.8841361114565804,0.995,False
5+
CTGANSynthesizer,child,3.200128,979.617808,32.617664,4.282878,1.142984,27.238076,0.9019419078947368,0.78,False
6+
CTGANSynthesizer,expedia_hotel_logs,0.200128,150.563925,15.91442,5.169491,0.748102,24.14625,0.6880639845798507,1.0,False
7+
CTGANSynthesizer,insurance,3.340128,1212.558941,46.762523,5.615406,1.511614,37.646447,0.8864513532763532,0.924,False
8+
CTGANSynthesizer,intrusion,162.039016,82280.437439,2657.139259,17.832163,72.566051,265.73881700000004,0.7818205075028868,0.979,True
9+
CTGANSynthesizer,news,18.712096,3544.486617,271.973699,2.231165,6.736061,102.504686,0.962816373010111,1.0,True
10+
CTGANSynthesizer,covtype,255.645408,95090.524928,3483.092798,6.744517,81.259212,316.717502,0.9392885541297548,1.0,True
11+
CopulaGANSynthesizer,adult,3.907448,2442.880482,129.075719,3.850393,2.192589,25.681516,0.9022687633936166,1.0,True
12+
CopulaGANSynthesizer,alarm,4.520128,1512.218437,60.27888,7.384842,2.058639,62.319537,0.8975387762762764,0.814,False
13+
CopulaGANSynthesizer,census,98.165608,121961.180374,3611.784183,73.397545,50.029892,475.011129,0.8976299557232021,0.936,True
14+
CopulaGANSynthesizer,child,3.200128,1044.897489,35.827332,4.283838,1.169998,28.161134,0.9201572368421052,0.718,False
15+
CopulaGANSynthesizer,expedia_hotel_logs,0.200128,164.466943,15.884971,5.153421,0.86762,25.326735,0.6164453240138844,1.0,False
16+
CopulaGANSynthesizer,insurance,3.340128,1294.988253,50.112184,5.616562,1.539449,38.978003,0.912103774928775,0.912,False
17+
CopulaGANSynthesizer,intrusion,162.039016,,,,,,,,False
18+
CopulaGANSynthesizer,news,18.712096,4187.002935,783.287077,2.653589,11.598527,107.001778,0.9484535611962412,1.0,True
19+
CopulaGANSynthesizer,covtype,255.645408,98195.677411,3739.273639,7.057522,187.937031,320.969169,0.7660242003877997,1.0,True
20+
GaussianCopulaSynthesizer,adult,3.907448,18.207913,41.026812,0.169494,1.825013,25.622773,0.8460627797682108,1.0,False
21+
GaussianCopulaSynthesizer,alarm,4.520128,37.699565,61.443354,0.301017,2.739793,61.26025,0.9824800675675674,0.963,False
22+
GaussianCopulaSynthesizer,census,98.165608,378.84696,1009.291104,0.361086,44.221692,466.572605,0.8894152679182734,1.0,False
23+
GaussianCopulaSynthesizer,child,3.200128,16.712313,33.335665,0.193836,1.502726,28.425445000000003,0.9680398684210528,0.979,False
24+
GaussianCopulaSynthesizer,expedia_hotel_logs,0.200128,10.89391,3.021129,0.292982,0.655214,24.683722,0.6967257630420702,1.0,False
25+
GaussianCopulaSynthesizer,insurance,3.340128,27.583096,44.920821,0.239005,1.988513,38.589909,0.970242663817664,0.997,False
26+
GaussianCopulaSynthesizer,intrusion,162.039016,287.947423,1665.469206,0.335467,62.790978,275.274062,0.7256691671182501,1.0,False
27+
GaussianCopulaSynthesizer,news,18.712096,51.417232,192.939966,0.454549,7.996991,106.262719,0.8898934188617345,1.0,False
28+
GaussianCopulaSynthesizer,covtype,255.645408,382.860954,2625.79807,0.425929,116.524616,302.983582,0.6326398468610831,1.0,False
29+
TVAESynthesizer,adult,3.907448,968.556432,128.19177,0.937476,1.375366,25.51443,0.8906560246902444,0.997,True
30+
TVAESynthesizer,alarm,4.520128,448.360238,55.775492,0.857249,1.660257,60.414942,0.9722228603603604,0.484,False
31+
TVAESynthesizer,census,98.165608,26935.849083,3570.955421,5.22611,32.642758,488.297111,0.9183024856637144,0.891,True
32+
TVAESynthesizer,child,3.200128,401.538141,32.619298,0.692832,0.936026,28.233848,0.9569536184210524,0.712,False
33+
TVAESynthesizer,expedia_hotel_logs,0.200128,38.050896,15.894987,0.843862,0.703003,25.770061,0.6853025665680748,0.917,False
34+
TVAESynthesizer,insurance,3.340128,415.711596,46.754888,0.768012,1.245651,38.6992,0.9259373931623932,0.921,False
35+
TVAESynthesizer,intrusion,162.039016,32531.630601,2657.12625,8.049807,70.083053,295.088169,0.8670193556776293,0.978,True
36+
TVAESynthesizer,news,18.712096,2997.100946,271.963851,1.853129,6.305529,98.561018,0.9311909881744532,1.0,True
37+
TVAESynthesizer,covtype,255.645408,53751.366715,3483.099878,9.562188,70.7207,336.707852,0.933603114374824,1.0,True
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
completed_date: 05_10_2024 16:00:12
2+
jobs: None
3+
run_id: run_05_10_2024_1
4+
sdgym_version: 0.8.0
5+
sdv_version: 1.24.0
6+
starting_date: 05_10_2024 15:56:03

0 commit comments

Comments
 (0)