|
8 | 8 | import pickle |
9 | 9 | import re |
10 | 10 | import textwrap |
| 11 | +import threading |
11 | 12 | import tracemalloc |
12 | 13 | import warnings |
13 | 14 | from collections import defaultdict |
|
42 | 43 | from sdgym.errors import SDGymError |
43 | 44 | from sdgym.metrics import get_metrics |
44 | 45 | from sdgym.progress import TqdmLogger, progress |
45 | | -from sdgym.result_writer import LocalResultsWriter |
| 46 | +from sdgym.result_writer import LocalResultsWriter, S3ResultsWriter |
46 | 47 | from sdgym.s3 import ( |
47 | 48 | S3_PREFIX, |
48 | 49 | S3_REGION, |
@@ -544,56 +545,52 @@ def _score_with_timeout( |
544 | 545 | synthesizer_path=None, |
545 | 546 | result_writer=None, |
546 | 547 | ): |
| 548 | + output = {} |
| 549 | + args = ( |
| 550 | + synthesizer, |
| 551 | + data, |
| 552 | + metadata, |
| 553 | + metrics, |
| 554 | + output, |
| 555 | + compute_quality_score, |
| 556 | + compute_diagnostic_score, |
| 557 | + compute_privacy_score, |
| 558 | + modality, |
| 559 | + dataset_name, |
| 560 | + synthesizer_path, |
| 561 | + result_writer, |
| 562 | + ) |
| 563 | + if isinstance(result_writer, S3ResultsWriter): |
| 564 | + process = threading.Thread( |
| 565 | + target=_score, |
| 566 | + args=args, |
| 567 | + daemon=True, |
| 568 | + ) |
| 569 | + process.start() |
| 570 | + process.join(timeout) |
| 571 | + if process.is_alive(): |
 | 572 | + LOGGER.error('Timeout running %s on dataset %s', synthesizer['name'], dataset_name) |
| 573 | + return {'timeout': True, 'error': 'Timeout'} |
| 574 | + |
 | 575 | + return output |
| 576 | + |
547 | 577 | with multiprocessing_context(): |
548 | 578 | with multiprocessing.Manager() as manager: |
549 | 579 | output = manager.dict() |
550 | | - |
551 | | - def safe_score(*args): |
552 | | - try: |
553 | | - _score(*args) |
554 | | - except Exception as e: |
555 | | - output['error'] = str(e) |
556 | | - |
557 | 580 | process = multiprocessing.Process( |
558 | | - target=safe_score, |
559 | | - args=( |
560 | | - synthesizer, |
561 | | - data, |
562 | | - metadata, |
563 | | - metrics, |
564 | | - output, |
565 | | - compute_quality_score, |
566 | | - compute_diagnostic_score, |
567 | | - compute_privacy_score, |
568 | | - modality, |
569 | | - dataset_name, |
570 | | - synthesizer_path, |
571 | | - result_writer, |
572 | | - ), |
| 581 | + target=_score, |
| 582 | + args=args, |
573 | 583 | ) |
574 | 584 |
|
575 | 585 | process.start() |
576 | 586 | process.join(timeout) |
| 587 | + process.terminate() |
577 | 588 |
|
578 | | - if process.is_alive(): |
579 | | - output['timeout'] = True |
580 | | - process.terminate() |
581 | | - process.join() # ensure termination completes |
582 | | - |
583 | | - result = dict(output) |
584 | | - if result.get('timeout'): |
585 | | - LOGGER.error( |
586 | | - 'Timeout running %s on dataset %s', |
587 | | - synthesizer['name'], dataset_name |
588 | | - ) |
589 | | - elif result.get('error'): |
590 | | - LOGGER.error( |
591 | | - 'Error running %s on dataset %s: %s', |
592 | | - synthesizer['name'], dataset_name, result['error'] |
593 | | - ) |
594 | | - |
595 | | - return result |
| 589 | + output = dict(output) |
| 590 | + if output.get('timeout'): |
 | 591 | + LOGGER.error('Timeout running %s on dataset %s', synthesizer['name'], dataset_name) |
596 | 592 |
|
| 593 | + return output |
597 | 594 |
|
598 | 595 |
|
599 | 596 | def _format_output( |
@@ -1293,7 +1290,7 @@ def _get_s3_script_content( |
1293 | 1290 | scores = _run_jobs(None, job_args_list, False, result_writer=result_writer) |
1294 | 1291 | run_id_filename = job_args_list[0][-1]['run_id'] |
1295 | 1292 | _update_run_id_file(run_id_filename, result_writer) |
1296 | | -#s3_client.delete_object(Bucket='{bucket_name}', Key='{job_args_key}') |
| 1293 | +s3_client.delete_object(Bucket='{bucket_name}', Key='{job_args_key}') |
1297 | 1294 | """ |
1298 | 1295 |
|
1299 | 1296 |
|
|
0 commit comments