Skip to content

Commit 42fc209

Browse files
authored
Add job to write dict contents to a csv file (#606)
Used to adhere to `csv.reader()` and escape inputs with possible issues
1 parent f48ebf5 commit 42fc209

File tree

1 file changed

+72
-13
lines changed

1 file changed

+72
-13
lines changed

text/processing.py

Lines changed: 72 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,19 @@
55
"TailJob",
66
"SetDifferenceJob",
77
"WriteToTextFileJob",
8+
"WriteToCsvFileJob",
89
"SplitTextFileJob",
910
]
1011

12+
import csv
13+
from io import IOBase
1114
import logging
1215
import os
1316
import shutil
1417
import subprocess
1518
from collections.abc import Iterable
1619
import tempfile
17-
from typing import List, Optional, Union
20+
from typing import Dict, List, Optional, Union
1821

1922
from sisyphus import Job, Task, Path, global_settings as gs, toolkit as tk
2023
from sisyphus.delayed_ops import DelayedBase
@@ -282,7 +285,17 @@ def run(self):
282285

283286
class WriteToTextFileJob(Job):
284287
"""
285-
Write a given content into a text file, one entry per line
288+
Write a given content into a text file, one entry per line.
289+
290+
This job supports multiple input types:
291+
1. String.
292+
2. Dictionary.
293+
3. Iterable.
294+
295+
The corresponding output for each of the inputs above is:
296+
1. The string is directly written into the file.
297+
2. Each key/value pair is written as `<key>: <value>`.
298+
3. Each element in the iterable is written in a separate line as a string.
286299
"""
287300

288301
__sis_hash_exclude__ = {"out_name": "file.txt"}
@@ -296,22 +309,68 @@ def __init__(self, content: Union[str, dict, Iterable, DelayedBase], out_name: s
296309

297310
self.out_file = self.output_path(out_name)
298311

312+
def write_content_to_file(self, file_handler: IOBase):
313+
content = util.instanciate_delayed(self.content)
314+
if isinstance(content, str):
315+
file_handler.write(content)
316+
elif isinstance(content, dict):
317+
for key, val in content.items():
318+
file_handler.write(f"{key}: {val}\n")
319+
elif isinstance(content, Iterable):
320+
for line in content:
321+
file_handler.write(f"{line}\n")
322+
else:
323+
raise NotImplementedError("Content of unknown type different from (str, dict, Iterable).")
324+
299325
def tasks(self):
300326
yield Task("run", mini_task=True)
301327

302328
def run(self):
303-
content = util.instanciate_delayed(self.content)
304329
with open(self.out_file.get_path(), "w") as f:
305-
if isinstance(content, str):
306-
f.write(content)
307-
elif isinstance(content, dict):
308-
for key, val in content.items():
309-
f.write(f"{key}: {val}\n")
310-
elif isinstance(content, Iterable):
311-
for line in content:
312-
f.write(f"{line}\n")
313-
else:
314-
raise NotImplementedError
330+
self.write_content_to_file(f)
331+
332+
333+
class WriteToCsvFileJob(WriteToTextFileJob):
334+
"""
335+
Write a given content into a csv file, one entry per line.
336+
337+
This job only supports dictionaries as input type. Each key/value pair is written as `<key><delimiter><value>`.
338+
"""
339+
340+
__sis_hash_exclude__ = {} # It was filled in the base class, but it's not needed anymore since this is a new job.
341+
342+
def __init__(
343+
self,
344+
content: Dict[str, Union[str, List[str]]],
345+
*,
346+
out_name: str = "file.txt",
347+
delimiter: str = "\t",
348+
):
349+
"""
350+
:param content: input which will be written into a text file
351+
:param out_name: user specific file name for the output file
352+
:param delimiter: Delimiter used to separate the different entries.
353+
"""
354+
super().__init__(content, out_name)
355+
356+
self.delimiter = delimiter
357+
358+
def write_content_to_file(self, file_handler: IOBase):
359+
"""
360+
Writes the input contents (from `self.content`) into the file provided as parameter as a csv file.
361+
362+
:param file_handler: Open file to write the contents of `self.content` to.
363+
"""
364+
csv_writer = csv.writer(file_handler, delimiter=self.delimiter)
365+
content = util.instanciate_delayed(self.content)
366+
if isinstance(content, dict):
367+
for key, val in content.items():
368+
if isinstance(val, list):
369+
csv_writer.writerow((key, *val))
370+
else:
371+
csv_writer.writerow((key, val))
372+
else:
373+
raise NotImplementedError("Content of unknown type different from (str, dict, Iterable).")
315374

316375

317376
class SplitTextFileJob(Job):

0 commit comments

Comments
 (0)