Skip to content

Commit 0b3f0f6

Browse files
committed
debug run with timeout 3
1 parent 2556dc4 commit 0b3f0f6

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

sdgym/benchmark.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -547,8 +547,15 @@ def _score_with_timeout(
547547
with multiprocessing_context():
548548
with multiprocessing.Manager() as manager:
549549
output = manager.dict()
550+
551+
def safe_score(*args):
552+
try:
553+
_score(*args)
554+
except Exception as e:
555+
output['error'] = str(e)
556+
550557
process = multiprocessing.Process(
551-
target=_score,
558+
target=safe_score,
552559
args=(
553560
synthesizer,
554561
data,
@@ -567,13 +574,26 @@ def _score_with_timeout(
567574

568575
process.start()
569576
process.join(timeout)
570-
process.terminate()
571577

572-
output = dict(output)
573-
if output.get('timeout'):
574-
LOGGER.error('Timeout running %s on dataset %s;', synthesizer['name'], dataset_name)
578+
if process.is_alive():
579+
output['timeout'] = True
580+
process.terminate()
581+
process.join() # ensure termination completes
582+
583+
result = dict(output)
584+
if result.get('timeout'):
585+
LOGGER.error(
586+
'Timeout running %s on dataset %s',
587+
synthesizer['name'], dataset_name
588+
)
589+
elif result.get('error'):
590+
LOGGER.error(
591+
'Error running %s on dataset %s: %s',
592+
synthesizer['name'], dataset_name, result['error']
593+
)
594+
595+
return result
575596

576-
return output
577597

578598

579599
def _format_output(
@@ -677,8 +697,6 @@ def _run_job(args):
677697
output = {}
678698
try:
679699
if timeout:
680-
print('LAAA')
681-
print(timeout)
682700
output = _score_with_timeout(
683701
timeout=timeout,
684702
synthesizer=synthesizer,

sdgym/run_benchmark/run_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def main():
4646
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
4747
aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
4848
date_str = datetime.now(timezone.utc).strftime('%Y-%m-%d')
49-
for synthesizer_group in SYNTHESIZERS_SPLIT:
49+
for synthesizer_group in SYNTHESIZERS_SPLIT[:2]:
5050
benchmark_single_table_aws(
5151
output_destination=OUTPUT_DESTINATION_AWS,
5252
aws_access_key_id=aws_access_key_id,

0 commit comments

Comments
 (0)