2222class ResultsHandler (ABC ):
2323 """Abstract base class for handling results storage and retrieval."""
2424
25+ def __init__ (self , baseline_synthesizer = SYNTHESIZER_BASELINE ):
26+ # Allow overrides per modality while maintaining the historical default.
27+ self .baseline_synthesizer = baseline_synthesizer or SYNTHESIZER_BASELINE
28+
2529 @abstractmethod
2630 def list (self ):
2731 """List all runs in the results directory."""
@@ -59,7 +63,8 @@ def _compute_wins(self, result):
5963 result ['Win' ] = 0
6064 for dataset in datasets :
6165 score_baseline = result .loc [
62- (result ['Synthesizer' ] == SYNTHESIZER_BASELINE ) & (result ['Dataset' ] == dataset )
66+ (result ['Synthesizer' ] == self .baseline_synthesizer )
67+ & (result ['Dataset' ] == dataset )
6368 ]['Quality_Score' ].to_numpy ()
6469 if score_baseline .size == 0 :
6570 continue
@@ -84,7 +89,7 @@ def _get_summarize_table(self, folder_to_results, folder_infos):
8489 f' - # datasets: { folder_infos [folder ]["# datasets" ]} '
8590 f' - sdgym version: { folder_infos [folder ]["sdgym_version" ]} '
8691 )
87- results = results .loc [results ['Synthesizer' ] != SYNTHESIZER_BASELINE ]
92+ results = results .loc [results ['Synthesizer' ] != self . baseline_synthesizer ]
8893 column_data = results .groupby (['Synthesizer' ])['Win' ].sum ()
8994 columns .append ((date_obj , column_name , column_data ))
9095
@@ -107,9 +112,11 @@ def _get_column_name_infos(self, folder_to_results):
107112 continue
108113
109114 metainfo_info = self ._load_yaml_file (folder , yaml_files [0 ])
110- num_datasets = results .loc [
111- results ['Synthesizer' ] == SYNTHESIZER_BASELINE , 'Dataset'
112- ].nunique ()
115+ baseline_mask = results ['Synthesizer' ] == self .baseline_synthesizer
116+ if baseline_mask .any ():
117+ num_datasets = results .loc [baseline_mask , 'Dataset' ].nunique ()
118+ else :
119+ num_datasets = results ['Dataset' ].nunique ()
113120 folder_to_info [folder ] = {
114121 'date' : metainfo_info .get ('starting_date' )[:NUM_DIGITS_DATE ],
115122 'sdgym_version' : metainfo_info .get ('sdgym_version' ),
@@ -236,7 +243,8 @@ def all_runs_complete(self, folder_name):
236243class LocalResultsHandler (ResultsHandler ):
237244 """Results handler for local filesystem."""
238245
239- def __init__ (self , base_path ):
246+ def __init__ (self , base_path , baseline_synthesizer = SYNTHESIZER_BASELINE ):
247+ super ().__init__ (baseline_synthesizer = baseline_synthesizer )
240248 self .base_path = base_path
241249
242250 def list (self ):
@@ -287,7 +295,8 @@ def _load_yaml_file(self, folder_name, file_name):
287295class S3ResultsHandler (ResultsHandler ):
288296 """Results handler for AWS S3 storage."""
289297
290- def __init__ (self , path , s3_client ):
298+ def __init__ (self , path , s3_client , baseline_synthesizer = SYNTHESIZER_BASELINE ):
299+ super ().__init__ (baseline_synthesizer = baseline_synthesizer )
291300 self .s3_client = s3_client
292301 self .bucket_name = path .split ('/' )[2 ]
293302 self .prefix = '/' .join (path .split ('/' )[3 :]).rstrip ('/' ) + '/'
@@ -396,8 +405,8 @@ def _get_results(self, folder_name, file_names):
396405 for file_name in file_names :
397406 s3_key = f'{ self .prefix } { folder_name } /{ file_name } '
398407 response = self .s3_client .get_object (Bucket = self .bucket_name , Key = s3_key )
399- df = pd .read_csv (io .BytesIO (response ['Body' ].read ()))
400- results .append (df )
408+ result_df = pd .read_csv (io .BytesIO (response ['Body' ].read ()))
409+ results .append (result_df )
401410
402411 return results
403412
0 commit comments