Skip to content

Commit bd28e4e

Browse files
committed
Update message for text files
1 parent 9f133a5 commit bd28e4e

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

sdv/datasets/demo.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,22 @@ def is_direct_json_under_prefix(key):
197197
)
198198

199199

200+
def _download_text_file_error_message(
    modality,
    dataset_name,
    output_filepath=None,
    bucket=PUBLIC_BUCKET,
    filename=None,
    **kwargs,
):
    """Build the error message used when a dataset text file cannot be retrieved.

    Args:
        modality:
            Not used in the message.
        dataset_name (str):
            Name of the dataset whose text file was requested.
        output_filepath:
            Not used in the message.
        bucket (str):
            Name of the bucket that was queried. Defaults to ``PUBLIC_BUCKET``.
        filename (str or None):
            Name of the text file that was requested. When ``None`` the message
            refers to the file generically instead of printing ``'None'``.
        **kwargs:
            Extra keyword arguments. ``s3_bucket_name`` is honoured as an
            alternative spelling of ``bucket``; anything else is ignored.

    Returns:
        str:
            The human-readable error message.
    """
    # NOTE(review): this helper is attached to ``get_source``, whose bucket
    # parameter is named ``s3_bucket_name``. Presumably the decorator forwards
    # the wrapped call's keyword arguments here -- confirm against
    # ``handle_aws_client_errors``. Without this lookup the message would name
    # the default bucket rather than the one actually requested.
    bucket = kwargs.get('s3_bucket_name', bucket)

    # Avoid rendering the literal text 'None' when no filename was forwarded.
    target = f"'{filename}'" if filename is not None else 'the text file'

    return (
        f"Could not retrieve {target} for dataset '{dataset_name}' "
        f"from bucket '{bucket}'. "
        'Make sure the bucket name is correct. If the bucket is private '
        'make sure to provide your credentials.'
    )
214+
215+
200216
def _download_error_message(
201217
modality,
202218
dataset_name,
@@ -640,6 +656,7 @@ def _save_document(text, output_filepath, filename, dataset_name):
640656
LOGGER.info(f'Error saving {filename} for dataset {dataset_name}.')
641657

642658

659+
@handle_aws_client_errors(_download_text_file_error_message)
643660
def _get_text_file_content(
644661
modality, dataset_name, filename, output_filepath=None, bucket=PUBLIC_BUCKET, credentials=None
645662
):
@@ -690,7 +707,7 @@ def _get_text_file_content(
690707
return text
691708

692709

693-
@handle_aws_client_errors(_download_error_message)
710+
@handle_aws_client_errors(_download_text_file_error_message)
694711
def get_source(
695712
modality, dataset_name, output_filepath=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
696713
):
@@ -727,7 +744,6 @@ def get_source(
727744
)
728745

729746

730-
@handle_aws_client_errors(_download_error_message)
731747
def get_readme(
732748
modality, dataset_name, output_filepath=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
733749
):

tests/unit/datasets/test_demo.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,7 +1465,8 @@ def test_get_source_with_client_error(mock__create_s3_client):
14651465
mock__create_s3_client.return_value = client
14661466

14671467
error_msg = (
1468-
"Could not download dataset 'fake_hotels' from bucket 'private_bucket'. "
1468+
"Could not retrieve 'SOURCE.txt' for dataset 'fake_hotels' "
1469+
"from bucket 'private_bucket'. "
14691470
'Make sure the bucket name is correct. If the bucket is private '
14701471
'make sure to provide your credentials.'
14711472
)
@@ -1499,7 +1500,8 @@ def test_get_readme_with_client_error(mock__create_s3_client):
14991500
mock__create_s3_client.return_value = client
15001501

15011502
error_msg = (
1502-
"Could not download dataset 'fake_hotels' from bucket 'private_bucket'. "
1503+
"Could not retrieve 'README.txt' for dataset 'fake_hotels' "
1504+
"from bucket 'private_bucket'. "
15031505
'Make sure the bucket name is correct. If the bucket is private '
15041506
'make sure to provide your credentials.'
15051507
)

0 commit comments

Comments
 (0)