@@ -18,12 +18,8 @@ def _validate_local_path(path):
1818 raise ValueError (f"The provided path '{ path } ' is not a valid local directory." )
1919
2020
21- _FOLDER_BY_MODALITY = {
22- 'single_table' : 'single_table' ,
23- 'multi_table' : 'multi_table' ,
24- }
25-
2621_BASELINE_BY_MODALITY = {
22+ 'single_table' : SYNTHESIZER_BASELINE ,
2723 'multi_table' : 'MultiTableUniformSynthesizer' ,
2824}
2925
@@ -38,45 +34,47 @@ def _resolve_effective_path(path, modality):
3834 if not modality :
3935 return path
4036
41- folder = _FOLDER_BY_MODALITY .get (modality )
42- if folder is None :
43- valid = ', ' .join (sorted (_FOLDER_BY_MODALITY ))
37+ if modality not in ('single_table' , 'multi_table' ):
38+ valid = ', ' .join (sorted (('single_table' , 'multi_table' )))
4439 raise ValueError (f'Invalid modality "{ modality } ". Valid options are: { valid } .' )
4540
4641 # Avoid double-appending if already included
47- if str (path ).rstrip ('/' ).endswith (('/' + folder , folder )):
42+ if str (path ).rstrip ('/' ).endswith (('/' + modality , modality )):
4843 return path
4944
5045 if is_s3_path (path ):
51- path = path .rstrip ('/' ) + '/' + folder
46+ path = path .rstrip ('/' ) + '/' + modality
5247 return path
5348
54- return os .path .join (path , folder )
49+ return os .path .join (path , modality )
5550
5651
5752class ResultsExplorer :
5853 """Explorer for SDGym benchmark results, supporting both local and S3 storage."""
5954
55+ def _create_results_handler (self , original_path , effective_path , baseline_synthesizer ):
56+ """Create the appropriate results handler for local or S3 storage."""
57+ if is_s3_path (original_path ):
58+ # Use original path to obtain client (keeps backwards compatibility),
59+ # but handler should operate on the modality-specific effective path.
60+ s3_client = _get_s3_client (
61+ original_path , self .aws_access_key_id , self .aws_secret_access_key
62+ )
63+ return S3ResultsHandler (
64+ effective_path , s3_client , baseline_synthesizer = baseline_synthesizer
65+ )
66+
67+ _validate_local_path (effective_path )
68+ return LocalResultsHandler (effective_path , baseline_synthesizer = baseline_synthesizer )
69+
6070 def __init__ (self , path , modality , aws_access_key_id = None , aws_secret_access_key = None ):
6171 self .path = path
62- self .modality = modality
72+ self .modality = modality . lower ()
6373 self .aws_access_key_id = aws_access_key_id
6474 self .aws_secret_access_key = aws_secret_access_key
65-
6675 baseline_synthesizer = _get_baseline_synthesizer (modality )
6776 effective_path = _resolve_effective_path (path , modality )
68- if is_s3_path (path ):
69- # Use original path to obtain client (keeps backwards compatibility),
70- # but handler should operate on the modality-specific effective path.
71- s3_client = _get_s3_client (path , aws_access_key_id , aws_secret_access_key )
72- self ._handler = S3ResultsHandler (
73- effective_path , s3_client , baseline_synthesizer = baseline_synthesizer
74- )
75- else :
76- _validate_local_path (effective_path )
77- self ._handler = LocalResultsHandler (
78- effective_path , baseline_synthesizer = baseline_synthesizer
79- )
77+ self ._handler = self ._create_results_handler (path , effective_path , baseline_synthesizer )
8078
8179 def list (self ):
8280 """List all runs available in the results directory."""
@@ -125,7 +123,7 @@ def load_real_data(self, dataset_name):
125123 )
126124
127125 data , _ = load_dataset (
128- modality = self .modality or 'single_table' ,
126+ modality = self .modality ,
129127 dataset = dataset_name ,
130128 aws_access_key_id = self .aws_access_key_id ,
131129 aws_secret_access_key = self .aws_secret_access_key ,
0 commit comments