datasets/huggingface.py: 253 changes (251 additions & 2 deletions)
@@ -2,12 +2,18 @@
https://huggingface.co/docs/datasets/
"""

from typing import Optional, Any, Union
from sisyphus import Job, Task, gs
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Any, Union, Callable, Sequence, Dict
from sisyphus import Job, Task, Path, gs
from sisyphus.delayed_ops import DelayedBase

from i6_core.util import instanciate_delayed

if TYPE_CHECKING:
from datasets import Dataset, DatasetDict

TransformFuncT = Union[Callable[[DatasetDict], DatasetDict], Callable[[Dataset], Dataset]]


class DownloadAndPrepareHuggingFaceDatasetJob(Job):
"""
@@ -112,3 +118,246 @@ def hash(cls, kwargs):
"token": kwargs["token"],
}
return super().hash(d)


class TransformAndMapHuggingFaceDatasetJob(Job):
"""
Runs some functions (e.g. filtering, mapping, renaming columns, ...) on an HF dataset.

The map step is handled with special logic, as it involves writing to disk:
we first write to the job work dir via cache_file_name(s),
then save_to_disk to the final output dir,
and finally clean up the work dir again.
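
Usage sketch (dataset path, column name and helper functions are just illustrative)::

    def _keep_short(example):
        # hypothetical predicate for filtering
        return len(example["text"]) < 1000

    def _filter_short(ds):
        # transform: works the same way on Dataset and DatasetDict
        return ds.filter(_keep_short)

    def _add_num_chars(example):
        # map function, applied per example
        return {"num_chars": len(example["text"])}

    job = TransformAndMapHuggingFaceDatasetJob(
        "some/dataset-or-path",
        transform=_filter_short,
        map_func=_add_num_chars,
    )
    # job.out_dir then contains the result as written by save_to_disk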
"""

def __init__(
self,
path: Union[str, Path],
name: Optional[str] = None,
*,
load_dataset_opts: Optional[Dict[str, Any]] = None, # e.g. "split", "revision", ...
non_hashed_load_dataset_opts: Optional[Dict[str, Any]] = None, # e.g. {"num_proc": 8}
transform: Union[None, TransformFuncT, Sequence[TransformFuncT]] = None,
map_func: Optional[Callable] = None,
map_opts: Union[
None, Dict[str, Any], Callable[[Dataset], Dict[str, Any]], Callable[[DatasetDict], Dict[str, Any]]
] = None,
non_hashed_map_opts: Optional[Dict[str, Any]] = None,
num_shards: Union[None, int, Dict[str, int]] = None,
max_shard_size: Union[None, str, int] = None,
):
"""
:param path: for :func:`datasets.load_dataset`,
:func:`datasets.Dataset.load_from_disk` or :func:`datasets.DatasetDict.load_from_disk`.
We automatically detect which one to use.
:param name: for :func:`datasets.load_dataset`
:param load_dataset_opts: other options for :func:`datasets.load_dataset`
or :func:`datasets.Dataset.load_from_disk` or :func:`datasets.DatasetDict.load_from_disk`.
E.g. "split", "revision", ...
:param non_hashed_load_dataset_opts: like ``load_dataset_opts``, but not hashed.
E.g. ``{"num_proc": 8}``.
:param transform: function or list of functions to transform the dataset
((Dataset) -> Dataset or (DatasetDict) -> DatasetDict).
E.g. filtering, renaming columns, ...
:param map_func: function to map the dataset examples, or batch of examples.
This is passed to :func:`datasets.Dataset.map` or :func:`datasets.DatasetDict.map`.
None (default) means identity.
:param map_opts: further options passed to :func:`datasets.Dataset.map` or :func:`datasets.DatasetDict.map`,
or a function that returns such options (e.g. depending on the dataset size).
E.g. ``{"batched": True, "batch_size": 1000}``.
:param non_hashed_map_opts: like ``map_opts``, but not hashed.
:param num_shards: how many shards to write via :func:`datasets.Dataset.save_to_disk`
or :func:`datasets.DatasetDict.save_to_disk`.
If not given, will be auto-detected based on the dataset size and ``max_shard_size``.
:param max_shard_size: maximum size of each shard.
If not given, will use ``"500MB"``.
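
Example of a callable ``map_opts`` (illustrative sketch, assuming a single :class:`datasets.Dataset` is loaded)::

    def _map_opts(ds):
        # hypothetical: use larger batches for larger datasets
        return {"batched": True, "batch_size": 1000 if ds.num_rows > 100_000 else 100}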
"""
super().__init__()

if max_shard_size is not None and num_shards is not None:
raise ValueError(f"{self}: please specify either max_shard_size or num_shards, but not both.")

self.path = path
self.name = name
self.load_dataset_opts = load_dataset_opts
self.non_hashed_load_dataset_opts = non_hashed_load_dataset_opts
self.transform = transform
self.map_func = map_func
self.map_opts = map_opts
self.non_hashed_map_opts = non_hashed_map_opts
self.num_shards = num_shards
self.max_shard_size = max_shard_size

self.rqmt = {"cpu": 16, "mem": 16, "time": 12}

self.out_dir = self.output_path("dataset", directory=True)

@classmethod
def hash(cls, kwargs):
kwargs.pop("non_hashed_load_dataset_opts")
kwargs.pop("non_hashed_map_opts")
Comment on lines +197 to +198

Contributor:

as per the discussion in rwth-i6/sisyphus#274, should we do a copy here?

Member Author:

Well, my comment from there (rwth-i6/sisyphus#274 (comment)):

> But I see that there is a lot of code which currently does it that way, so I don't think we can change this now.

The (shallow) copy on kwargs would need to be done anyway in JobSingleton.__call__ if we want to fix rwth-i6/sisyphus#274. So then modifying kwargs here is safe, and consistent with most other code.
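For reference, a minimal sketch of the copy-based alternative discussed above (same resulting hash, but the caller's kwargs stay untouched; purely illustrative):

    d = dict(kwargs)  # shallow copy
    d.pop("non_hashed_load_dataset_opts")
    d.pop("non_hashed_map_opts")
    return super().hash(d)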

return super().hash(kwargs)

def tasks(self):
yield Task("run", resume="run", rqmt=self.rqmt)

def run(self):
import os
import shutil
from datasets import load_dataset, Dataset, DatasetDict
from datasets.utils.py_utils import convert_file_size_to_int
from datasets import config

dataset_path = instanciate_delayed(self.path)
split = None
load_dataset_opts = (instanciate_delayed(self.load_dataset_opts) or {}).copy()
load_dataset_opts.update(instanciate_delayed(self.non_hashed_load_dataset_opts) or {})
if self.name is not None:
load_dataset_opts["name"] = self.name
if "split" in load_dataset_opts:
split = load_dataset_opts["split"]
path_ext = f"{dataset_path}/{split}" if split is not None else dataset_path
ds = None
if os.path.exists(path_ext):
if os.path.isdir(path_ext):
if os.path.exists(f"{path_ext}/{config.DATASET_INFO_FILENAME}") and os.path.exists(
f"{path_ext}/{config.DATASET_STATE_JSON_FILENAME}"
):
load_dataset_opts.pop("split", None)
ds = Dataset.load_from_disk(path_ext, **load_dataset_opts)
elif os.path.exists(f"{path_ext}/{config.DATASETDICT_JSON_FILENAME}"):
load_dataset_opts.pop("split", None)
ds = DatasetDict.load_from_disk(path_ext, **load_dataset_opts)
elif path_ext.endswith(".arrow"):
load_dataset_opts.pop("split", None)
ds = Dataset.from_file(path_ext, **load_dataset_opts)

if ds is None:
# Use load_dataset.
# That can potentially download the dataset, so make sure that HF_HOME is set.
assert os.environ.get("HF_HOME"), (
"HF_HOME env var not set,"
" set this in your settings.py DEFAULT_ENVIRONMENT_SET"
" (if not CLEANUP_ENVIRONMENT, otherwise in your current env),"
" or via job.set_env"
)
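# HF_HOME can also be set per job, e.g. (sketch, the path is a placeholder):
#   job = TransformAndMapHuggingFaceDatasetJob(...)
#   job.set_env("HF_HOME", "/path/to/hf_home")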

ds = load_dataset(dataset_path, **load_dataset_opts)
assert isinstance(ds, (Dataset, DatasetDict))

if self.transform:
if callable(self.transform):
ds = self.transform(ds)
assert isinstance(ds, (Dataset, DatasetDict)), f"After {self.transform} got {type(ds)}"
else:
for func in self.transform:
ds = func(ds)
assert isinstance(ds, (Dataset, DatasetDict)), f"After {func} got {type(ds)}"

work_out_d = "tmp-map-output"
if os.path.exists(work_out_d):
shutil.rmtree(work_out_d)
os.makedirs(work_out_d)
map_opts = self.map_opts
if callable(map_opts):
map_opts = map_opts(ds)
map_extra_opts = {}
if self.non_hashed_map_opts and "num_proc" in self.non_hashed_map_opts:
    num_proc = self.non_hashed_map_opts["num_proc"]
else:
    num_proc = self.rqmt["cpu"] * 2
    # Only set num_proc here: if it is in non_hashed_map_opts, it is already passed to map() below,
    # and passing it twice would raise a duplicate-keyword error.
    map_extra_opts["num_proc"] = num_proc
if self.map_func:
ds = ds.map(
self.map_func,
**(map_opts or {}),
**(self.non_hashed_map_opts or {}),
**({"cache_file_name": f"{work_out_d}/data.arrow"} if isinstance(ds, Dataset) else {}),
**(
{"cache_file_names": {k: f"{work_out_d}/data-{k}.arrow" for k in ds.keys()}}
if isinstance(ds, DatasetDict)
else {}
),
**map_extra_opts,
)

num_shards = self.num_shards
max_shard_size = self.max_shard_size or "500MB"
max_shard_size = convert_file_size_to_int(max_shard_size)
if num_shards is None:
# This code is adapted from Dataset.save_to_disk to determine the number of shards.
# We make this independent of num_proc (because num_proc is not hashed).
if isinstance(ds, DatasetDict):
# noinspection PyProtectedMember
num_shards = {k: int(ds_._estimate_nbytes() / max_shard_size) + 1 for k, ds_ in ds.items()}
elif isinstance(ds, Dataset):
# noinspection PyProtectedMember
num_shards = int(ds._estimate_nbytes() / max_shard_size) + 1
else:
raise TypeError(f"Unexpected type: {type(ds)}")

ds.save_to_disk(self.out_dir.get_path(), num_shards=num_shards, num_proc=num_proc)
del ds
shutil.rmtree(work_out_d)


class ExtractTextFromHuggingFaceDatasetJob(Job):
"""
Extract a text column from an HF dataset and write it to a gzipped text file.
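
Usage sketch (dataset name, config and column are just an example)::

    job = ExtractTextFromHuggingFaceDatasetJob(
        "wikitext", name="wikitext-2-raw-v1", split="train", column_name="text"
    )
    # job.out_text is a gzip-compressed text file, one row of the column per line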
"""

def __init__(
self,
path: Union[str, Path],
name: Optional[str] = None,
*,
split: Optional[str] = "train",
column_name: str = "text",
):
"""
:param path: for :func:`datasets.load_dataset`
:param name: for :func:`datasets.load_dataset`
:param split: for :func:`datasets.load_dataset`
:param column_name: name of the text column to extract
"""
super().__init__()
self.path = path
self.name = name
self.split = split
self.column_name = column_name

self.rqmt = {"cpu": 4, "mem": 8, "time": 10}

self.out_text = self.output_path("text.txt.gz")

def tasks(self):
yield Task("run", resume="run", rqmt=self.rqmt)

def run(self):
import sys
import gzip
import time
from datasets import load_dataset, Dataset

ds = load_dataset(self.path, self.name, split=self.split)
assert isinstance(ds, Dataset), f"Expected a Dataset, got {type(ds)} {ds}"
assert self.column_name in ds.column_names, f"Column name {self.column_name} not in columns {ds.column_names}"

def _hms(s):
m, s = divmod(s, 60)
h, m = divmod(m, 60)
return "%d:%02d:%02d" % (h, m, s)

size = ds.num_rows
start_time = time.monotonic()
with gzip.open(self.out_text.get_path(), "wt", encoding="utf-8") as f:
for i, item in enumerate(ds):
if (i + 1) % 10000 == 0 or i + 1 == size:
elapsed = time.monotonic() - start_time
speed = (i + 1) / elapsed if elapsed > 0 else 0
eta = (size - (i + 1)) / speed if speed > 0 else float("inf")
eta_str = _hms(eta) if eta != float("inf") else "inf"
print(f"Line {i + 1}/{size}, {((i + 1) / size * 100):.1f}%, {speed:.1f} it/s, ETA {eta_str}")
sys.stdout.flush()
f.write(item[self.column_name])
f.write("\n")