1616
1717from sdgym .result_explorer .result_explorer import ResultsExplorer
1818from sdgym .result_writer import LocalResultsWriter
19- from sdgym .run_benchmark .utils import OUTPUT_DESTINATION_AWS , get_df_to_plot
19+ from sdgym .run_benchmark .utils import OUTPUT_DESTINATION_AWS , _parse_args , get_df_to_plot
2020from sdgym .s3 import S3_REGION , parse_s3_path
2121
2222LOGGER = logging .getLogger (__name__ )
@@ -45,17 +45,21 @@ def get_latest_run_from_file(s3_client, bucket, key):
4545 raise RuntimeError (f'Failed to read { key } from S3: { e } ' )
4646
4747
48- def write_uploaded_marker (s3_client , bucket , prefix , folder_name ):
48+ def write_uploaded_marker (s3_client , bucket , prefix , folder_name , modality = 'single_table' ):
4949 """Write a marker file to indicate that the upload is complete."""
5050 s3_client .put_object (
51- Bucket = bucket , Key = f'{ prefix } { folder_name } /upload_complete.marker' , Body = b'Upload complete'
51+ Bucket = bucket ,
52+ Key = f'{ prefix } { modality } /{ folder_name } /upload_complete.marker' ,
53+ Body = b'Upload complete' ,
5254 )
5355
5456
55- def upload_already_done (s3_client , bucket , prefix , folder_name ):
57+ def upload_already_done (s3_client , bucket , prefix , folder_name , modality = 'single_table' ):
5658 """Check if the upload has already been done by looking for the marker file."""
5759 try :
58- s3_client .head_object (Bucket = bucket , Key = f'{ prefix } { folder_name } /upload_complete.marker' )
60+ s3_client .head_object (
61+ Bucket = bucket , Key = f'{ prefix } { modality } /{ folder_name } /upload_complete.marker'
62+ )
5963 return True
6064 except ClientError as e :
6165 if e .response ['Error' ]['Code' ] == '404' :
@@ -64,7 +68,9 @@ def upload_already_done(s3_client, bucket, prefix, folder_name):
6468 raise
6569
6670
67- def get_result_folder_name_and_s3_vars (aws_access_key_id , aws_secret_access_key ):
71+ def get_result_folder_name_and_s3_vars (
72+ aws_access_key_id , aws_secret_access_key , modality = 'single_table'
73+ ):
6874 """Get the result folder name and S3 client variables."""
6975 bucket , prefix = parse_s3_path (OUTPUT_DESTINATION_AWS )
7076 s3_client = boto3 .client (
@@ -73,7 +79,9 @@ def get_result_folder_name_and_s3_vars(aws_access_key_id, aws_secret_access_key)
7379 aws_secret_access_key = aws_secret_access_key ,
7480 region_name = S3_REGION ,
7581 )
76- folder_infos = get_latest_run_from_file (s3_client , bucket , f'{ prefix } _BENCHMARK_DATES.json' )
82+ folder_infos = get_latest_run_from_file (
83+ s3_client , bucket , f'{ prefix } { modality } /_BENCHMARK_DATES.json'
84+ )
7785
7886 return folder_infos , s3_client , bucket , prefix
7987
@@ -109,14 +117,21 @@ def upload_to_drive(file_path, file_id):
109117
110118
111119def upload_results (
112- aws_access_key_id , aws_secret_access_key , folder_infos , s3_client , bucket , prefix , github_env
120+ aws_access_key_id ,
121+ aws_secret_access_key ,
122+ folder_infos ,
123+ s3_client ,
124+ bucket ,
125+ prefix ,
126+ github_env ,
127+ modality = 'single_table' ,
113128):
114129 """Upload benchmark results to S3, GDrive, and save locally."""
115130 folder_name = folder_infos ['folder_name' ]
116131 run_date = folder_infos ['date' ]
117132 result_explorer = ResultsExplorer (
118133 OUTPUT_DESTINATION_AWS ,
119- modality = 'single_table' ,
134+ modality = modality ,
120135 aws_access_key_id = aws_access_key_id ,
121136 aws_secret_access_key = aws_secret_access_key ,
122137 )
@@ -145,7 +160,7 @@ def upload_results(
145160
146161 Path (local_export_dir ).mkdir (parents = True , exist_ok = True )
147162 local_file_path = str (Path (local_export_dir ) / RESULT_FILENAME )
148- s3_key = f'{ prefix } { RESULT_FILENAME } '
163+ s3_key = f'{ prefix } { modality } / { RESULT_FILENAME } '
149164 s3_client .download_file (bucket , s3_key , local_file_path )
150165 datas = {
151166 'Wins' : summary ,
@@ -155,20 +170,22 @@ def upload_results(
155170 local_results_writer .write_xlsx (datas , local_file_path )
156171 upload_to_drive ((local_file_path ), SDGYM_FILE_ID )
157172 s3_client .upload_file (local_file_path , bucket , s3_key )
158- write_uploaded_marker (s3_client , bucket , prefix , folder_name )
173+ write_uploaded_marker (s3_client , bucket , prefix , folder_name , modality = modality )
159174 if temp_dir :
160175 shutil .rmtree (temp_dir )
161176
162177
163178def main ():
164179 """Main function to upload benchmark results."""
180+ args = _parse_args ()
181+ modality = args .modality
165182 aws_access_key_id = os .getenv ('AWS_ACCESS_KEY_ID' )
166183 aws_secret_access_key = os .getenv ('AWS_SECRET_ACCESS_KEY' )
167184 folder_infos , s3_client , bucket , prefix = get_result_folder_name_and_s3_vars (
168- aws_access_key_id , aws_secret_access_key
185+ aws_access_key_id , aws_secret_access_key , modality = modality
169186 )
170187 github_env = os .getenv ('GITHUB_ENV' )
171- if upload_already_done (s3_client , bucket , prefix , folder_infos ['folder_name' ]):
188+ if upload_already_done (s3_client , bucket , prefix , folder_infos ['folder_name' ], modality ):
172189 LOGGER .warning ('Benchmark results have already been uploaded. Exiting.' )
173190 if github_env :
174191 with open (github_env , 'a' ) as env_file :
@@ -184,6 +201,7 @@ def main():
184201 bucket ,
185202 prefix ,
186203 github_env ,
204+ modality ,
187205 )
188206
189207
0 commit comments