Skip to content

Commit b93a237

Browse files
committed
debug run with timeout 4
1 parent 0b3f0f6 commit b93a237

File tree

2 files changed

+40
-44
lines changed

2 files changed

+40
-44
lines changed

sdgym/benchmark.py

Lines changed: 39 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pickle
99
import re
1010
import textwrap
11+
import threading
1112
import tracemalloc
1213
import warnings
1314
from collections import defaultdict
@@ -42,7 +43,7 @@
4243
from sdgym.errors import SDGymError
4344
from sdgym.metrics import get_metrics
4445
from sdgym.progress import TqdmLogger, progress
45-
from sdgym.result_writer import LocalResultsWriter
46+
from sdgym.result_writer import LocalResultsWriter, S3ResultsWriter
4647
from sdgym.s3 import (
4748
S3_PREFIX,
4849
S3_REGION,
@@ -544,56 +545,52 @@ def _score_with_timeout(
544545
synthesizer_path=None,
545546
result_writer=None,
546547
):
548+
output = {}
549+
args = (
550+
synthesizer,
551+
data,
552+
metadata,
553+
metrics,
554+
output,
555+
compute_quality_score,
556+
compute_diagnostic_score,
557+
compute_privacy_score,
558+
modality,
559+
dataset_name,
560+
synthesizer_path,
561+
result_writer,
562+
)
563+
if isinstance(result_writer, S3ResultsWriter):
564+
process = threading.Thread(
565+
target=_score,
566+
args=args,
567+
daemon=True,
568+
)
569+
process.start()
570+
process.join(timeout)
571+
if process.is_alive():
572+
LOGGER.error('Timeout running %s on dataset %s;', synthesizer['name'], dataset_name)
573+
return {'timeout': True, 'error': 'Timeout'}
574+
575+
return process.result
576+
547577
with multiprocessing_context():
548578
with multiprocessing.Manager() as manager:
549579
output = manager.dict()
550-
551-
def safe_score(*args):
552-
try:
553-
_score(*args)
554-
except Exception as e:
555-
output['error'] = str(e)
556-
557580
process = multiprocessing.Process(
558-
target=safe_score,
559-
args=(
560-
synthesizer,
561-
data,
562-
metadata,
563-
metrics,
564-
output,
565-
compute_quality_score,
566-
compute_diagnostic_score,
567-
compute_privacy_score,
568-
modality,
569-
dataset_name,
570-
synthesizer_path,
571-
result_writer,
572-
),
581+
target=_score,
582+
args=args,
573583
)
574584

575585
process.start()
576586
process.join(timeout)
587+
process.terminate()
577588

578-
if process.is_alive():
579-
output['timeout'] = True
580-
process.terminate()
581-
process.join() # ensure termination completes
582-
583-
result = dict(output)
584-
if result.get('timeout'):
585-
LOGGER.error(
586-
'Timeout running %s on dataset %s',
587-
synthesizer['name'], dataset_name
588-
)
589-
elif result.get('error'):
590-
LOGGER.error(
591-
'Error running %s on dataset %s: %s',
592-
synthesizer['name'], dataset_name, result['error']
593-
)
594-
595-
return result
589+
output = dict(output)
590+
if output.get('timeout'):
591+
LOGGER.error('Timeout running %s on dataset %s;', synthesizer['name'], dataset_name)
596592

593+
return output
597594

598595

599596
def _format_output(
@@ -1293,7 +1290,7 @@ def _get_s3_script_content(
12931290
scores = _run_jobs(None, job_args_list, False, result_writer=result_writer)
12941291
run_id_filename = job_args_list[0][-1]['run_id']
12951292
_update_run_id_file(run_id_filename, result_writer)
1296-
#s3_client.delete_object(Bucket='{bucket_name}', Key='{job_args_key}')
1293+
s3_client.delete_object(Bucket='{bucket_name}', Key='{job_args_key}')
12971294
"""
12981295

12991296

sdgym/run_benchmark/run_benchmark.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
OUTPUT_DESTINATION_AWS,
1313
SYNTHESIZERS_SPLIT,
1414
get_result_folder_name,
15-
post_benchmark_launch_message,
1615
)
1716
from sdgym.s3 import get_s3_client, parse_s3_path
1817

@@ -58,7 +57,7 @@ def main():
5857
)
5958

6059
append_benchmark_run(aws_access_key_id, aws_secret_access_key, date_str)
61-
#post_benchmark_launch_message(date_str)
60+
# post_benchmark_launch_message(date_str)
6261

6362

6463
if __name__ == '__main__':

0 commit comments

Comments
 (0)