1919import os
2020
2121# Third Party
22- from datasets import DatasetDict , IterableDataset , IterableDatasetDict
22+ from datasets import DatasetDict , Features , IterableDataset , IterableDatasetDict
2323from PIL import Image
2424import yaml
2525
@@ -70,15 +70,15 @@ def resolve_iterable_dataset_features(data: IterableDataset):
7070 return data
7171
7272
73- def __get_dataset_features (d , default_split = "train" ):
73+ def __get_dataset_features (d , default_split : str = "train" ) -> Features :
7474 return (
7575 d [default_split ].features
76- if isinstance (d , (DatasetDict or IterableDatasetDict ))
76+ if isinstance (d , (DatasetDict , IterableDatasetDict ))
7777 else d .features
7878 )
7979
8080
81- def _maybe_cast_columns (datasets , default_split = "train" ):
81+ def _maybe_cast_columns (datasets : list , default_split : str = "train" ) -> None :
8282 """
8383 Given list of datasets, try casting datasets to same features.
8484 Assumes that the datasets are aligned in terms of columns which
@@ -95,7 +95,7 @@ def _maybe_cast_columns(datasets, default_split="train"):
9595 datasets [i ] = datasets [i ].cast (features )
9696
9797
98- def _validate_mergeable_datasets (datasets , default_split = "train" ):
98+ def _validate_mergeable_datasets (datasets : list , default_split : str = "train" ) -> None :
9999 """Given list of datasets, validate if all datasets have same type and number of columns."""
100100 if len (datasets ) <= 1 :
101101 return
@@ -122,16 +122,16 @@ def _validate_mergeable_datasets(datasets, default_split="train"):
122122 )
123123
124124
125- def maybe_align_datasets (datasets ) :
125+ def maybe_align_datasets (datasets : list ) -> None :
126126 """
127127 Given list of datasets
128128 1. validate if all datasets have same type and number of columns.
129129 2. try casting dataset columns to same value to ensure mergability
130130 """
131131 try :
132- for d in datasets :
132+ for i , d in enumerate ( datasets ) :
133133 if isinstance (d , IterableDataset ):
134- d = resolve_iterable_dataset_features (d )
134+ datasets [ i ] = resolve_iterable_dataset_features (d )
135135
136136 _validate_mergeable_datasets (datasets )
137137 _maybe_cast_columns (datasets )
0 commit comments