Skip to content

Commit 0e15d4a

Browse files
committed
Add option to allow newlines in captions
The YFCC-15M descriptions can contain newlines in the caption, which causes pyarrow's CSV reader to raise an error by default. This commit allows passing --newlines-in-captions True to img2dataset, which tells pyarrow to allow newlines inside CSV values (via ParseOptions(newlines_in_values=True)).
1 parent fc3fb2e commit 0e15d4a

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

img2dataset/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def download(
104104
max_shard_retry: int = 1,
105105
user_agent_token: Optional[str] = None,
106106
disallowed_header_directives: Optional[List[str]] = None,
107+
newlines_in_captions: bool = False,
107108
):
108109
"""Download is the main entry point of img2dataset, it uses multiple processes and download multiple files"""
109110
if disallowed_header_directives is None:
@@ -183,6 +184,7 @@ def signal_handler(signal_arg, frame): # pylint: disable=unused-argument
183184
number_sample_per_shard,
184185
done_shards,
185186
tmp_path,
187+
newlines_in_captions,
186188
)
187189

188190
if output_format == "webdataset":

img2dataset/reader.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
number_sample_per_shard,
3939
done_shards,
4040
tmp_path,
41+
newlines_in_captions,
4142
) -> None:
4243
self.input_format = input_format
4344
self.url_col = url_col
@@ -47,6 +48,7 @@ def __init__(
4748
self.save_additional_columns = save_additional_columns
4849
self.number_sample_per_shard = number_sample_per_shard
4950
self.done_shards = done_shards
51+
self.newlines_in_captions = newlines_in_captions
5052

5153
fs, url_path = fsspec.core.url_to_fs(url_list)
5254
self.fs = fs
@@ -79,13 +81,22 @@ def _save_to_arrow(self, input_file, start_shard_id):
7981
if self.input_format in ["txt", "json", "csv", "tsv"]:
8082
with self.fs.open(input_file, mode="rb") as file:
8183
if self.input_format == "txt":
82-
df = csv_pq.read_csv(file, read_options=csv_pq.ReadOptions(column_names=["url"]))
84+
df = csv_pq.read_csv(
85+
file,
86+
read_options=csv_pq.ReadOptions(column_names=["url"]),
87+
parse_options=csv_pq.ParseOptions(newlines_in_values=self.newlines_in_captions),
88+
)
8389
elif self.input_format == "json":
8490
df = pa.Table.from_pandas(pd.read_json(file))
8591
elif self.input_format == "csv":
86-
df = csv_pq.read_csv(file)
92+
df = csv_pq.read_csv(
93+
file, parse_options=csv_pq.ParseOptions(newlines_in_values=self.newlines_in_captions)
94+
)
8795
elif self.input_format == "tsv":
88-
df = csv_pq.read_csv(file, parse_options=csv_pq.ParseOptions(delimiter="\t"))
96+
df = csv_pq.read_csv(
97+
file,
98+
parse_options=csv_pq.ParseOptions(delimiter="\t", newlines_in_values=self.newlines_in_captions),
99+
)
89100
else:
90101
raise ValueError(f"Unknown input format {self.input_format}")
91102
elif self.input_format == "tsv.gz":

tests/test_reader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def test_reader(input_format, tmp_path):
4949
number_sample_per_shard=batch_size,
5050
done_shards=done_shards,
5151
tmp_path=test_folder,
52+
newlines_in_captions=False,
5253
)
5354

5455
if input_format == "txt":

0 commit comments

Comments
 (0)