2626class ResultsHandler (ABC ):
2727 """Abstract base class for handling results storage and retrieval."""
2828
def __init__(self, baseline_synthesizer=SYNTHESIZER_BASELINE):
    """Record which synthesizer is treated as the baseline.

    Args:
        baseline_synthesizer:
            Value stored on ``self.baseline_synthesizer``. Other methods of
            this class compare it against the ``Synthesizer`` column of the
            results. Any falsy value (for example ``None`` or an empty
            string) is replaced by ``SYNTHESIZER_BASELINE``.
    """
    # Allow overrides per modality while maintaining the historical default:
    # a falsy override falls back to the module-level constant.
    if not baseline_synthesizer:
        baseline_synthesizer = SYNTHESIZER_BASELINE

    self.baseline_synthesizer = baseline_synthesizer
32+
2933 @abstractmethod
3034 def list (self ):
3135 """List all runs in the results directory."""
@@ -63,7 +67,8 @@ def _compute_wins(self, result):
6367 result ['Win' ] = 0
6468 for dataset in datasets :
6569 score_baseline = result .loc [
66- (result ['Synthesizer' ] == SYNTHESIZER_BASELINE ) & (result ['Dataset' ] == dataset )
70+ (result ['Synthesizer' ] == self .baseline_synthesizer )
71+ & (result ['Dataset' ] == dataset )
6772 ]['Quality_Score' ].to_numpy ()
6873 if score_baseline .size == 0 :
6974 continue
@@ -88,7 +93,7 @@ def _get_summarize_table(self, folder_to_results, folder_infos):
8893 f' - # datasets: { folder_infos [folder ]["# datasets" ]} '
8994 f' - sdgym version: { folder_infos [folder ]["sdgym_version" ]} '
9095 )
91- results = results .loc [results ['Synthesizer' ] != SYNTHESIZER_BASELINE ]
96+ results = results .loc [results ['Synthesizer' ] != self . baseline_synthesizer ]
9297 column_data = results .groupby (['Synthesizer' ])['Win' ].sum ()
9398 columns .append ((date_obj , column_name , column_data ))
9499
@@ -111,9 +116,11 @@ def _get_column_name_infos(self, folder_to_results):
111116 continue
112117
113118 metainfo_info = self ._load_yaml_file (folder , yaml_files [0 ])
114- num_datasets = results .loc [
115- results ['Synthesizer' ] == SYNTHESIZER_BASELINE , 'Dataset'
116- ].nunique ()
119+ baseline_mask = results ['Synthesizer' ] == self .baseline_synthesizer
120+ if baseline_mask .any ():
121+ num_datasets = results .loc [baseline_mask , 'Dataset' ].nunique ()
122+ else :
123+ num_datasets = results ['Dataset' ].nunique ()
117124 folder_to_info [folder ] = {
118125 'date' : metainfo_info .get ('starting_date' )[:NUM_DIGITS_DATE ],
119126 'sdgym_version' : metainfo_info .get ('sdgym_version' ),
@@ -240,7 +247,8 @@ def all_runs_complete(self, folder_name):
240247class LocalResultsHandler (ResultsHandler ):
241248 """Results handler for local filesystem."""
242249
def __init__(self, base_path, baseline_synthesizer=SYNTHESIZER_BASELINE):
    """Initialize the local results handler.

    Args:
        base_path:
            Stored unchanged on ``self.base_path``.
        baseline_synthesizer:
            Forwarded to ``ResultsHandler.__init__``. Defaults to
            ``SYNTHESIZER_BASELINE``.
    """
    # Let the base class resolve and store the baseline synthesizer first.
    super().__init__(baseline_synthesizer=baseline_synthesizer)

    self.base_path = base_path
245253
246254 def list (self ):
@@ -295,7 +303,8 @@ def _load_yaml_file(self, folder_name, file_name):
295303class S3ResultsHandler (ResultsHandler ):
296304 """Results handler for AWS S3 storage."""
297305
298- def __init__ (self , path , s3_client ):
306+ def __init__ (self , path , s3_client , baseline_synthesizer = SYNTHESIZER_BASELINE ):
307+ super ().__init__ (baseline_synthesizer = baseline_synthesizer )
299308 self .s3_client = s3_client
300309 self .bucket_name = path .split ('/' )[2 ]
301310 self .prefix = '/' .join (path .split ('/' )[3 :]).rstrip ('/' ) + '/'
@@ -415,8 +424,8 @@ def _get_results(self, folder_name, file_names):
415424 for file_name in file_names :
416425 s3_key = f'{ self .prefix } { folder_name } /{ file_name } '
417426 response = self .s3_client .get_object (Bucket = self .bucket_name , Key = s3_key )
418- df = pd .read_csv (io .BytesIO (response ['Body' ].read ()))
419- results .append (df )
427+ result_df = pd .read_csv (io .BytesIO (response ['Body' ].read ()))
428+ results .append (result_df )
420429
421430 return results
422431
0 commit comments