diff --git a/img2dataset/main.py b/img2dataset/main.py index 4150c19..fbbc688 100644 --- a/img2dataset/main.py +++ b/img2dataset/main.py @@ -108,6 +108,7 @@ def download( max_shard_retry: int = 1, user_agent_token: Optional[str] = None, disallowed_header_directives: Optional[List[str]] = None, + newlines_in_captions: bool = False, ): """Download is the main entry point of img2dataset, it uses multiple processes and download multiple files""" if disallowed_header_directives is None: @@ -192,6 +193,7 @@ def signal_handler(signal_arg, frame): # pylint: disable=unused-argument number_sample_per_shard, done_shards, tmp_path, + newlines_in_captions, start_shard_id, ) diff --git a/img2dataset/reader.py b/img2dataset/reader.py index 2c17c3b..2546782 100644 --- a/img2dataset/reader.py +++ b/img2dataset/reader.py @@ -40,6 +40,7 @@ def __init__( number_sample_per_shard, done_shards, tmp_path, + newlines_in_captions, start_shard_id: int = 0, ) -> None: self.input_format = input_format @@ -50,6 +51,7 @@ def __init__( self.save_additional_columns = save_additional_columns self.number_sample_per_shard = number_sample_per_shard self.done_shards = done_shards + self.newlines_in_captions = newlines_in_captions self.start_shard_id = start_shard_id fs, url_path = fsspec.core.url_to_fs(url_list) @@ -97,13 +99,22 @@ def _save_to_arrow(self, input_file, start_shard_id): compression = "gzip" with self.fs.open(input_file, encoding="utf-8", mode="rb", compression=compression) as file: if self.input_format in ["txt", "txt.gz"]: - df = csv_pa.read_csv(file, read_options=csv_pa.ReadOptions(column_names=["url"])) + df = csv_pa.read_csv( + file, + read_options=csv_pa.ReadOptions(column_names=["url"]), + parse_options=csv_pa.ParseOptions(newlines_in_values=self.newlines_in_captions), + ) elif self.input_format in ["json", "json.gz"]: df = pa.Table.from_pandas(pd.read_json(file)) elif self.input_format in ["csv", "csv.gz"]: - df = csv_pa.read_csv(file) + df = csv_pa.read_csv( + file, parse_options=csv_pa.ParseOptions(newlines_in_values=self.newlines_in_captions) + ) elif self.input_format in ["tsv", "tsv.gz"]: - df = csv_pa.read_csv(file, parse_options=csv_pa.ParseOptions(delimiter="\t")) + df = csv_pa.read_csv( + file, + parse_options=csv_pa.ParseOptions(delimiter="\t", newlines_in_values=self.newlines_in_captions), + ) elif self.input_format in ["jsonl", "jsonl.gz"]: df = json_pa.read_json(file) else: diff --git a/tests/test_reader.py b/tests/test_reader.py index da483a6..d3d67ed 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -54,6 +54,7 @@ def test_reader(input_format, tmp_path): number_sample_per_shard=batch_size, done_shards=done_shards, tmp_path=test_folder, + newlines_in_captions=False, ) if input_format in ["txt", "txt.gz"]: