Skip to content

Commit 5d266a8

Browse files
committed
Feedback
1 parent 392f8b1 commit 5d266a8

File tree

4 files changed

+15
-30
lines changed

4 files changed

+15
-30
lines changed

sdgym/datasets.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,18 @@ def _get_bucket_name(bucket):
3434
return bucket[len(S3_PREFIX) :] if bucket.startswith(S3_PREFIX) else bucket
3535

3636

37-
def _validate_dataset_availability(
37+
def _raise_dataset_not_found_error(
3838
s3_client,
3939
bucket_name,
4040
dataset_name,
4141
current_modality,
42-
display_name,
4342
bucket,
4443
modality,
4544
):
46-
"""Return modalities where the dataset exists in the bucket, excluding the current modality."""
45+
display_name = dataset_name
46+
if isinstance(dataset_name, Path):
47+
display_name = dataset_name.name
48+
4749
available_modalities = []
4850
for other_modality in MODALITIES:
4951
if other_modality == current_modality:
@@ -85,12 +87,8 @@ def _download_dataset(
8587

8688
contents = _list_s3_bucket_contents(s3_client, bucket_name, prefix)
8789
if not contents:
88-
display_name = dataset_name
89-
if isinstance(dataset_name, Path):
90-
display_name = dataset_name.name
91-
92-
_validate_dataset_availability(
93-
s3_client, bucket_name, dataset_name, modality, display_name, bucket, modality
90+
_raise_dataset_not_found_error(
91+
s3_client, bucket_name, dataset_name, modality, bucket, modality
9492
)
9593

9694
for obj in contents:

sdgym/result_explorer/result_explorer.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,12 @@ def _validate_local_path(path):
2626

2727
def _resolve_effective_path(path, modality):
2828
"""Append the modality folder to the given base path if provided."""
29-
if not modality:
30-
return path
31-
32-
_validate_modality(modality)
33-
3429
# Avoid double-appending if already included
3530
if str(path).rstrip('/').endswith(('/' + modality, modality)):
3631
return path
3732

3833
if is_s3_path(path):
39-
path = path.rstrip('/') + '/' + modality
40-
return path
34+
return path.rstrip('/') + '/' + modality
4135

4236
return os.path.join(path, modality)
4337

@@ -59,12 +53,15 @@ def _create_results_handler(self, original_path, effective_path):
5953
_validate_local_path(effective_path)
6054
return LocalResultsHandler(effective_path, baseline_synthesizer=baseline_synthesizer)
6155

62-
def __init__(self, path, modality, aws_access_key_id=None, aws_secret_access_key=None):
56+
def __init__(
57+
self, path, modality='single_table', aws_access_key_id=None, aws_secret_access_key=None
58+
):
6359
self.path = path
60+
_validate_modality(modality)
6461
self.modality = modality.lower()
6562
self.aws_access_key_id = aws_access_key_id
6663
self.aws_secret_access_key = aws_secret_access_key
67-
effective_path = _resolve_effective_path(path, modality)
64+
effective_path = _resolve_effective_path(path, self.modality)
6865
self._handler = self._create_results_handler(path, effective_path)
6966

7067
def list(self):

sdgym/result_explorer/result_handler.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
import os
66
from abc import ABC, abstractmethod
77
from datetime import datetime
8-
from io import BytesIO
9-
from zipfile import ZipFile
108

119
import cloudpickle
1210
import pandas as pd
@@ -27,7 +25,6 @@ class ResultsHandler(ABC):
2725
"""Abstract base class for handling results storage and retrieval."""
2826

2927
def __init__(self, baseline_synthesizer=SYNTHESIZER_BASELINE):
30-
# Allow overrides per modality while maintaining the historical default.
3128
self.baseline_synthesizer = baseline_synthesizer or SYNTHESIZER_BASELINE
3229

3330
@abstractmethod
@@ -395,15 +392,7 @@ def load_synthetic_data(self, file_path):
395392
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
396393
body = response['Body'].read()
397394
if file_path.endswith('.zip'):
398-
tables = {}
399-
with ZipFile(BytesIO(body)) as zf:
400-
for name in zf.namelist():
401-
if name.endswith('.csv'):
402-
table_name = os.path.splitext(os.path.basename(name))[0]
403-
with zf.open(name) as csv_file:
404-
tables[table_name] = pd.read_csv(csv_file)
405-
406-
return tables
395+
return _read_zipped_data(io.BytesIO(body), modality='multi_table')
407396

408397
return pd.read_csv(io.BytesIO(body))
409398

sdgym/run_benchmark/upload_benchmark_results.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def upload_results(
116116
run_date = folder_infos['date']
117117
result_explorer = ResultsExplorer(
118118
OUTPUT_DESTINATION_AWS,
119+
modality='single_table',
119120
aws_access_key_id=aws_access_key_id,
120121
aws_secret_access_key=aws_secret_access_key,
121122
)

0 commit comments

Comments
 (0)