Skip to content

Commit 3a76b31

Browse files
committed
cleaning
1 parent 962ae20 commit 3a76b31

File tree

3 files changed

+13
-35
lines changed

3 files changed

+13
-35
lines changed

sdgym/benchmark.py

Lines changed: 5 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545
from sdgym.result_writer import LocalResultsWriter
4646
from sdgym.s3 import (
4747
S3_PREFIX,
48+
S3_REGION,
4849
is_s3_path,
4950
parse_s3_path,
5051
write_csv,
@@ -331,7 +332,6 @@ def _compute_scores(
331332
modality,
332333
dataset_name,
333334
):
334-
LOGGER.info('ROM Computing scores for dataset %s', dataset_name)
335335
metrics = metrics or []
336336
if len(metrics) > 0:
337337
metrics, metric_kwargs = get_metrics(metrics, modality='single-table')
@@ -369,7 +369,6 @@ def _compute_scores(
369369
# re-inject list to multiprocessing output
370370
output['scores'] = scores
371371

372-
LOGGER.info('ROM before diagnostic score')
373372
if compute_diagnostic_score:
374373
start = get_utc_now()
375374
if modality == 'single_table':
@@ -381,7 +380,6 @@ def _compute_scores(
381380
output['diagnostic_score_time'] = calculate_score_time(start)
382381
output['diagnostic_score'] = diagnostic_report.get_score()
383382

384-
LOGGER.info('ROM before quality score')
385383
if compute_quality_score:
386384
start = get_utc_now()
387385
if modality == 'single_table':
@@ -390,12 +388,9 @@ def _compute_scores(
390388
quality_report = MultiTableQualityReport()
391389

392390
quality_report.generate(real_data, synthetic_data, metadata, verbose=False)
393-
LOGGER.info('ROM Quality report generated')
394391
output['quality_score_time'] = calculate_score_time(start)
395-
LOGGER.info('ROM before quality score get_score')
396392
output['quality_score'] = quality_report.get_score()
397393

398-
LOGGER.info('ROM before privacy score')
399394
if compute_privacy_score:
400395
start = get_utc_now()
401396
num_rows = len(synthetic_data)
@@ -1202,7 +1197,7 @@ def _validate_aws_inputs(output_destination, aws_access_key_id, aws_secret_acces
12021197
's3',
12031198
aws_access_key_id=aws_access_key_id,
12041199
aws_secret_access_key=aws_secret_access_key,
1205-
region_name='us-east-1',
1200+
region_name=S3_REGION,
12061201
config=config,
12071202
)
12081203
else:
@@ -1248,31 +1243,9 @@ def _get_s3_script_content(
12481243
return f"""
12491244
import boto3
12501245
import pickle
1251-
import base64
1252-
import pandas as pd
1253-
import sdgym
1254-
import logging
1255-
from sdgym.synthesizers.sdv import (
1256-
CopulaGANSynthesizer, CTGANSynthesizer,
1257-
GaussianCopulaSynthesizer, HMASynthesizer, PARSynthesizer,
1258-
SDVRelationalSynthesizer, SDVTabularSynthesizer, TVAESynthesizer
1259-
)
1260-
from sdgym.synthesizers import RealTabFormerSynthesizer
12611246
from sdgym.benchmark import _run_jobs, _write_run_id_file, _update_run_id_file
12621247
from io import StringIO
12631248
from sdgym.result_writer import S3ResultsWriter
1264-
import sys
1265-
1266-
logging.basicConfig(
1267-
level=logging.INFO,
1268-
format='%(asctime)s - %(levelname)s - %(message)s',
1269-
stream=sys.stdout
1270-
)
1271-
1272-
LOGGER = logging.getLogger(__name__)
1273-
LOGGER.info("This should show up on CloudWatch / logs")
1274-
1275-
12761249
12771250
s3_client = boto3.client(
12781251
's3',
@@ -1337,11 +1310,10 @@ def _run_on_aws(
13371310
aws_secret_access_key,
13381311
):
13391312
bucket_name, job_args_key = _store_job_args_in_s3(output_destination, job_args_list, s3_client)
1340-
region_name = 'us-east-1'
13411313
script_content = _get_s3_script_content(
13421314
aws_access_key_id,
13431315
aws_secret_access_key,
1344-
region_name,
1316+
S3_REGION,
13451317
bucket_name,
13461318
job_args_key,
13471319
synthesizers,
@@ -1351,12 +1323,12 @@ def _run_on_aws(
13511323
session = boto3.session.Session(
13521324
aws_access_key_id=aws_access_key_id,
13531325
aws_secret_access_key=aws_secret_access_key,
1354-
region_name=region_name,
1326+
region_name=S3_REGION,
13551327
)
13561328
ec2_client = session.client('ec2')
13571329
print(f'This instance is being created in region: {session.region_name}') # noqa
13581330
user_data_script = _get_user_data_script(
1359-
aws_access_key_id, aws_secret_access_key, region_name, script_content
1331+
aws_access_key_id, aws_secret_access_key, S3_REGION, script_content
13601332
)
13611333
response = ec2_client.run_instances(
13621334
ImageId='ami-080e1f13689e07408',

sdgym/s3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pandas as pd
1111

1212
S3_PREFIX = 's3://'
13+
S3_REGION = 'us-east-1'
1314
LOGGER = logging.getLogger(__name__)
1415

1516

tests/unit/test_benchmark.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
benchmark_single_table_aws,
2626
)
2727
from sdgym.result_writer import LocalResultsWriter
28+
from sdgym.s3 import S3_REGION
2829
from sdgym.synthesizers import GaussianCopulaSynthesizer
2930

3031

@@ -560,7 +561,11 @@ def test_validate_aws_inputs_valid(mock_config, mock_check_write_permissions, mo
560561

561562
# Assert
562563
mock_boto3_client.assert_called_once_with(
563-
's3', aws_access_key_id='AKIA...', aws_secret_access_key='SECRET', config=config_mock
564+
's3',
565+
aws_access_key_id='AKIA...',
566+
aws_secret_access_key='SECRET',
567+
region_name=S3_REGION,
568+
config=config_mock,
564569
)
565570
s3_client_mock.head_bucket.assert_called_once_with(Bucket='my-bucket')
566571
mock_check_write_permissions.assert_called_once_with(s3_client_mock, 'my-bucket')
@@ -654,7 +659,7 @@ def test_benchmark_single_table_aws(
654659
output_destination=output_destination,
655660
compute_quality_score=True,
656661
compute_diagnostic_score=True,
657-
compute_privacy_score=True,
662+
compute_privacy_score=False,
658663
synthesizers=synthesizers,
659664
detailed_results_folder=None,
660665
custom_synthesizers=None,

0 commit comments

Comments
 (0)