Skip to content

Commit 716bf12

Browse files
committed
make variable name consistent
1 parent 04fdd1c commit 716bf12

File tree

15 files changed

+163
-97
lines changed

15 files changed

+163
-97
lines changed

sdgym/_run_benchmark/run_benchmark.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77

88

99
def main():
10-
aws_key = os.getenv('AWS_ACCESS_KEY_ID')
11-
aws_secret = os.getenv('AWS_SECRET_ACCESS_KEY')
10+
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
11+
aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
1212
for synthesizer in ['GaussianCopulaSynthesizer', 'TVAESynthesizer']:
1313
benchmark_single_table_aws(
1414
output_destination=run_benchmark.OUTPUT_DESTINATION_AWS,
15-
aws_access_key_id=aws_key,
16-
aws_secret_access_key=aws_secret,
15+
aws_access_key_id=aws_access_key_id,
16+
aws_secret_access_key=aws_secret_access_key,
1717
synthesizers=[synthesizer],
1818
sdv_datasets=datasets,
1919
compute_privacy_score=False,

sdgym/_run_benchmark/upload_benchmark_results.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,16 @@ def upload_results(aws_access_key_id, aws_secret_access_key, run_name, s3_client
8787

8888

8989
def main():
90-
aws_key = os.getenv('AWS_ACCESS_KEY_ID')
91-
aws_secret = os.getenv('AWS_SECRET_ACCESS_KEY')
92-
run_name, s3_client, bucket, prefix = get_run_name_and_s3_vars(aws_key, aws_secret)
90+
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
91+
aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
92+
run_name, s3_client, bucket, prefix = get_run_name_and_s3_vars(
93+
aws_access_key_id, aws_secret_access_key
94+
)
9395
if upload_already_done(s3_client, bucket, prefix, run_name):
9496
LOGGER.info('Benchmark results have already been uploaded. Exiting.')
9597
sys.exit(0)
9698

97-
upload_results(aws_key, aws_secret, run_name, s3_client, bucket, prefix)
99+
upload_results(aws_access_key_id, aws_secret_access_key, run_name, s3_client, bucket, prefix)
98100

99101

100102
if __name__ == '__main__':

sdgym/benchmark.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,11 +242,25 @@ def _generate_job_args_list(
242242
synthesizers = get_synthesizers(synthesizers + custom_synthesizers)
243243

244244
# Get list of dataset paths
245-
sdv_datasets = [] if sdv_datasets is None else get_dataset_paths(datasets=sdv_datasets)
245+
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
246+
aws_secret_access_key_key = os.getenv('AWS_SECRET_ACCESS_KEY')
247+
sdv_datasets = (
248+
[]
249+
if sdv_datasets is None
250+
else get_dataset_paths(
251+
datasets=sdv_datasets,
252+
aws_access_key_id=aws_access_key_id,
253+
aws_secret_access_key=aws_secret_access_key_key,
254+
)
255+
)
246256
additional_datasets = (
247257
[]
248258
if additional_datasets_folder is None
249-
else get_dataset_paths(bucket=additional_datasets_folder)
259+
else get_dataset_paths(
260+
bucket=additional_datasets_folder,
261+
aws_access_key_id=aws_access_key_id,
262+
aws_secret_access_key=aws_secret_access_key_key,
263+
)
250264
)
251265
datasets = sdv_datasets + additional_datasets
252266
synthesizer_names = [synthesizer['name'] for synthesizer in synthesizers]

sdgym/cli/__main__.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,16 @@ def _download_datasets(args):
9898
datasets = args.datasets
9999
if not datasets:
100100
datasets = sdgym.datasets.get_available_datasets(
101-
args.bucket, args.aws_key, args.aws_secret
101+
args.bucket, args.aws_access_key_id, args.aws_secret_access_key
102102
)['name']
103103

104104
for dataset in tqdm.tqdm(datasets):
105105
sdgym.datasets.load_dataset(
106-
dataset, args.datasets_path, args.bucket, args.aws_key, args.aws_secret
106+
dataset,
107+
args.datasets_path,
108+
args.bucket,
109+
args.aws_access_key_id,
110+
args.aws_secret_access_key,
107111
)
108112

109113

@@ -114,7 +118,9 @@ def _list_downloaded(args):
114118

115119

116120
def _list_available(args):
117-
datasets = sdgym.datasets.get_available_datasets(args.bucket, args.aws_key, args.aws_secret)
121+
datasets = sdgym.datasets.get_available_datasets(
122+
args.bucket, args.aws_access_key_id, args.aws_secret_access_key
123+
)
118124
_print_table(datasets, args.sort, args.reverse, {'size': humanfriendly.format_size})
119125

120126

@@ -125,16 +131,16 @@ def _list_synthesizers(args):
125131

126132
def _collect(args):
127133
sdgym.cli.collect.collect_results(
128-
args.input_path, args.output_file, args.aws_key, args.aws_secret
134+
args.input_path, args.output_file, args.aws_access_key_id, args.aws_secret_access_key
129135
)
130136

131137

132138
def _summary(args):
133139
sdgym.cli.summary.make_summary_spreadsheet(
134140
args.input_path,
135141
output_path=args.output_file,
136-
aws_key=args.aws_key,
137-
aws_secret=args.aws_secret,
142+
aws_access_key_id=args.aws_access_key_id,
143+
aws_secret_access_key=args.aws_secret_access_key,
138144
)
139145

140146

sdgym/cli/collect.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from sdgym.s3 import write_csv
55

66

7-
def collect_results(input_path, output_file=None, aws_key=None, aws_secret=None):
7+
def collect_results(
8+
input_path, output_file=None, aws_access_key_id=None, aws_secret_access_key=None
9+
):
810
"""Collect the results in the given input directory.
911
1012
Write all the results into one csv file.
@@ -15,15 +17,15 @@ def collect_results(input_path, output_file=None, aws_key=None, aws_secret=None)
1517
output_file (str):
1618
If ``output_file`` is provided, the consolidated results will be written there.
1719
Otherwise, they will be written to ``input_path``/results.csv.
18-
aws_key (str):
19-
If an ``aws_key`` is provided, the given access key id will be used to read from
20-
and/or write to any s3 paths.
21-
aws_secret (str):
22-
If an ``aws_secret`` is provided, the given secret access key will be used to read
23-
from and/or write to any s3 paths.
20+
aws_access_key_id (str):
21+
If an ``aws_access_key_id`` is provided, the given access key id will be used
22+
to read from and/or write to any s3 paths.
23+
aws_secret_access_key (str):
24+
If an ``aws_secret_access_key`` is provided, the given secret access key will
25+
be used to read from and/or write to any s3 paths.
2426
"""
2527
print(f'Reading results from {input_path}') # noqa: T201
26-
scores = read_csv_from_path(input_path, aws_key, aws_secret)
28+
scores = read_csv_from_path(input_path, aws_access_key_id, aws_secret_access_key)
2729
scores = scores.drop_duplicates()
2830

2931
if output_file:
@@ -32,4 +34,4 @@ def collect_results(input_path, output_file=None, aws_key=None, aws_secret=None)
3234
output = f'{input_path}/results.csv'
3335

3436
print(f'Storing results at {output}') # noqa: T201
35-
write_csv(scores, output, aws_key, aws_secret)
37+
write_csv(scores, output, aws_access_key_id, aws_secret_access_key)

sdgym/cli/summary.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,11 @@ def _add_summary(data, modality, baselines, writer):
289289

290290

291291
def make_summary_spreadsheet(
292-
results_csv_path, output_path=None, baselines=None, aws_key=None, aws_secret=None
292+
results_csv_path,
293+
output_path=None,
294+
baselines=None,
295+
aws_access_key_id=None,
296+
aws_secret_access_key=None,
293297
):
294298
"""Create a spreadsheet document organizing information from results.
295299
@@ -307,7 +311,7 @@ def make_summary_spreadsheet(
307311
Optional dict mapping modalities to a list of baseline
308312
model names. If not provided, a default dict is used.
309313
"""
310-
results = read_csv(results_csv_path, aws_key, aws_secret)
314+
results = read_csv(results_csv_path, aws_access_key_id, aws_secret_access_key)
311315
data = preprocess(results)
312316
baselines = baselines or MODALITY_BASELINES
313317
output_path = output_path or re.sub('.csv$', '.xlsx', results_csv_path)
@@ -319,4 +323,4 @@ def make_summary_spreadsheet(
319323
_add_summary(df, modality, modality_baselines, writer)
320324

321325
writer.save()
322-
write_file(output.getvalue(), output_path, aws_key, aws_secret)
326+
write_file(output.getvalue(), output_path, aws_access_key_id, aws_secret_access_key)

sdgym/cli/utils.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,25 @@
1111
from sdgym.s3 import get_s3_client, is_s3_path, parse_s3_path
1212

1313

14-
def read_file(path, aws_key, aws_secret):
14+
def read_file(path, aws_access_key_id, aws_secret_access_key):
1515
"""Read file from path.
1616
1717
The path can either be a local path or an s3 directory.
1818
1919
Args:
2020
path (str):
2121
The path to the file.
22-
aws_key (str):
22+
aws_access_key_id (str):
2323
The access key id that will be used to communicate with s3, if provided.
24-
aws_secret (str):
24+
aws_secret_access_key (str):
2525
The secret access key that will be used to communicate with s3, if provided.
2626
2727
Returns:
2828
bytes:
2929
The content of the file in bytes.
3030
"""
3131
if is_s3_path(path):
32-
s3 = get_s3_client(aws_key, aws_secret)
32+
s3 = get_s3_client(aws_access_key_id, aws_secret_access_key)
3333
bucket_name, key = parse_s3_path(path)
3434
obj = s3.get_object(Bucket=bucket_name, Key=key)
3535
contents = obj['Body'].read()
@@ -40,28 +40,28 @@ def read_file(path, aws_key, aws_secret):
4040
return contents
4141

4242

43-
def read_csv(path, aws_key, aws_secret):
43+
def read_csv(path, aws_access_key_id, aws_secret_access_key):
4444
"""Read csv file from path.
4545
4646
The path can either be a local path or an s3 directory.
4747
4848
Args:
4949
path (str):
5050
The path to the csv file.
51-
aws_key (str):
51+
aws_access_key_id (str):
5252
The access key id that will be used to communicate with s3, if provided.
53-
aws_secret (str):
53+
aws_secret_access_key (str):
5454
The secret access key that will be used to communicate with s3, if provided.
5555
5656
Returns:
5757
pandas.DataFrame:
5858
A DataFrame containing the contents of the csv file.
5959
"""
60-
contents = read_file(path, aws_key, aws_secret)
60+
contents = read_file(path, aws_access_key_id, aws_secret_access_key)
6161
return pd.read_csv(io.BytesIO(contents))
6262

6363

64-
def read_csv_from_path(path, aws_key, aws_secret):
64+
def read_csv_from_path(path, aws_access_key_id, aws_secret_access_key):
6565
"""Read all csv content within a path.
6666
6767
All csv content within a path will be read and returned in a
@@ -70,9 +70,9 @@ def read_csv_from_path(path, aws_key, aws_secret):
7070
Args:
7171
path (str):
7272
The path to read from, which can be either local or an s3 path.
73-
aws_key (str):
73+
aws_access_key_id (str):
7474
The access key id that will be used to communicate with s3, if provided.
75-
aws_secret (str):
75+
aws_secret_access_key (str):
7676
The secret access key that will be used to communicate with s3, if provided.
7777
7878
Returns:
@@ -81,13 +81,17 @@ def read_csv_from_path(path, aws_key, aws_secret):
8181
"""
8282
csv_contents = []
8383
if is_s3_path(path):
84-
s3 = get_s3_client(aws_key, aws_secret)
84+
s3 = get_s3_client(aws_access_key_id, aws_secret_access_key)
8585
bucket_name, key_prefix = parse_s3_path(path)
8686
resp = s3.list_objects(Bucket=bucket_name, Prefix=key_prefix)
8787
csv_files = [f for f in resp['Contents'] if f['Key'].endswith('.csv')]
8888
for csv_file in csv_files:
8989
csv_file_key = csv_file['Key']
90-
csv_contents.append(read_csv(f's3://{bucket_name}/{csv_file_key}', aws_key, aws_secret))
90+
csv_contents.append(
91+
read_csv(
92+
f's3://{bucket_name}/{csv_file_key}', aws_access_key_id, aws_secret_access_key
93+
)
94+
)
9195

9296
else:
9397
run_path = pathlib.Path(path)

0 commit comments

Comments
 (0)