Skip to content

Commit d9ba6e7

Browse files
authored
Simplify the import API for SDGym's results explorer (#432)
1 parent b4e8813 commit d9ba6e7

File tree

22 files changed

+42
-44
lines changed

22 files changed

+42
-44
lines changed

pyproject.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,8 @@ dependencies = [
4747
"scipy>=1.12.0;python_version>='3.12' and python_version<'3.13'",
4848
"scipy>=1.14.1;python_version>='3.13'",
4949
'tabulate>=0.8.3,<0.9',
50-
"torch>=1.13.0;python_version<'3.11'",
51-
"torch>=2.0.0;python_version>='3.11' and python_version<'3.12'",
52-
"torch>=2.2.0;python_version>='3.12' and python_version<'3.13'",
53-
"torch>=2.6.0;python_version>='3.13'",
50+
"torch>=2.2.0;python_version>='3.8' and python_version<'3.9'",
51+
"torch>=2.6.0;python_version>='3.9'",
5452
'tqdm>=4.66.3',
5553
'XlsxWriter>=1.2.8',
5654
'rdt>=1.17.0',

sdgym/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sdgym.cli.summary import make_summary_spreadsheet
1818
from sdgym.datasets import get_available_datasets, load_dataset
1919
from sdgym.synthesizers import create_sdv_synthesizer_variant, create_single_table_synthesizer
20+
from sdgym.result_explorer import ResultsExplorer
2021

2122
# Clear the logging wrongfully configured by tensorflow/absl
2223
list(map(logging.root.removeHandler, logging.root.handlers))
@@ -30,4 +31,5 @@
3031
'get_available_datasets',
3132
'create_sdv_synthesizer_variant',
3233
'create_single_table_synthesizer',
34+
'ResultsExplorer',
3335
]

sdgym/progress.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,7 @@ def _draw_bar(self, remaining, total, **kwargs):
5656
eta = datetime.utcnow() + remaining_time
5757

5858
elapsed = timedelta(seconds=self.elapsed)
59-
msg = ( # noqa: SFS201
60-
'[{0:<{1}}] | {2}/{3} ({4}%) Completed | {5} | {6} | {7}'
61-
).format(
59+
msg = ('[{0:<{1}}] | {2}/{3} ({4}%) Completed | {5} | {6} | {7}').format(
6260
progress_bar, self.width, done, total, percent, elapsed, remaining_time, eta
6361
)
6462
self.logger.info(msg)

sdgym/result_explorer/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Benchmark Results Explorer for SDGym."""
2+
3+
from sdgym.result_explorer.result_explorer import ResultsExplorer
4+
5+
__all__ = ['ResultsExplorer']

sdgym/sdgym_result_explorer/result_explorer.py renamed to sdgym/result_explorer/result_explorer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
from sdgym.benchmark import DEFAULT_DATASETS
66
from sdgym.datasets import get_dataset_paths, load_dataset
7+
from sdgym.result_explorer.result_handler import LocalResultsHandler, S3ResultsHandler
78
from sdgym.s3 import _get_s3_client, is_s3_path
8-
from sdgym.sdgym_result_explorer.result_handler import LocalResultsHandler, S3ResultsHandler
99

1010

1111
def _validate_local_path(path):
@@ -14,7 +14,7 @@ def _validate_local_path(path):
1414
raise ValueError(f"The provided path '{path}' is not a valid local directory.")
1515

1616

17-
class SDGymResultsExplorer:
17+
class ResultsExplorer:
1818
"""Explorer for SDGym benchmark results, supporting both local and S3 storage."""
1919

2020
def __init__(self, path, aws_access_key_id=None, aws_secret_access_key=None):
File renamed without changes.

sdgym/run_benchmark/upload_benchmark_results.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
from pydrive2.auth import GoogleAuth
1515
from pydrive2.drive import GoogleDrive
1616

17+
from sdgym.result_explorer.result_explorer import ResultsExplorer
1718
from sdgym.result_writer import LocalResultsWriter
1819
from sdgym.run_benchmark.utils import OUTPUT_DESTINATION_AWS, get_df_to_plot
1920
from sdgym.s3 import S3_REGION, parse_s3_path
20-
from sdgym.sdgym_result_explorer.result_explorer import SDGymResultsExplorer
2121

2222
LOGGER = logging.getLogger(__name__)
2323
SYNTHESIZER_TO_GLOBAL_POSITION = {
@@ -114,7 +114,7 @@ def upload_results(
114114
"""Upload benchmark results to S3, GDrive, and save locally."""
115115
folder_name = folder_infos['folder_name']
116116
run_date = folder_infos['date']
117-
result_explorer = SDGymResultsExplorer(
117+
result_explorer = ResultsExplorer(
118118
OUTPUT_DESTINATION_AWS,
119119
aws_access_key_id=aws_access_key_id,
120120
aws_secret_access_key=aws_secret_access_key,

sdgym/sdgym_result_explorer/__init__.py

Lines changed: 0 additions & 5 deletions
This file was deleted.
File renamed without changes.

tests/integration/sdgym_result_explorer/_benchmark_results/SDGym_results_04_05_2024/results_04_05_2024_1.csv renamed to tests/integration/result_explorer/_benchmark_results/SDGym_results_04_05_2024/results_04_05_2024_1.csv

File renamed without changes.

0 commit comments

Comments
 (0)