Commit 5c8830c

Signed-off-by: Harikrishnan Balagopal <harikrishmenon@gmail.com>
1 parent 00781fc commit 5c8830c

2 files changed: +10 -9 lines changed


tuning/data/data_processors.py

Lines changed: 2 additions & 1 deletion

@@ -470,7 +470,8 @@ def _process_dataset_configs(
                 final_datasets[k].append(v)
 
         # Ensure again datasets are aligned before interleaving or concatenating
-        maybe_align_datasets(final_datasets)
+        for v in final_datasets.values():
+            maybe_align_datasets(v)
 
         if sample_datasets:
             strategy = self.processor_config.sampling_stopping_strategy
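
Why the loop matters: `final_datasets` appears to map split names to lists of datasets (it is populated via `final_datasets[k].append(v)`), and iterating a dict yields its keys, not those lists. A minimal sketch of the distinction, using a hypothetical dict of placeholder values:

final_datasets = {"train": ["ds_a", "ds_b"], "validation": ["ds_c"]}

# Iterating the dict directly yields split names (strings), not datasets.
print(list(final_datasets))           # ['train', 'validation']

# Iterating .values() yields each list of datasets, which is the shape a
# list-based alignment helper such as maybe_align_datasets expects.
for group in final_datasets.values():
    print(group)                      # ['ds_a', 'ds_b'], then ['ds_c']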

tuning/data/utils.py

Lines changed: 8 additions & 8 deletions
@@ -19,7 +19,7 @@
 import os
 
 # Third Party
-from datasets import DatasetDict, IterableDataset, IterableDatasetDict
+from datasets import DatasetDict, Features, IterableDataset, IterableDatasetDict
 from PIL import Image
 import yaml
 

@@ -70,15 +70,15 @@ def resolve_iterable_dataset_features(data: IterableDataset):
     return data
 
 
-def __get_dataset_features(d, default_split="train"):
+def __get_dataset_features(d, default_split: str = "train") -> Features:
     return (
         d[default_split].features
-        if isinstance(d, (DatasetDict or IterableDatasetDict))
+        if isinstance(d, (DatasetDict, IterableDatasetDict))
         else d.features
     )
 
 
-def _maybe_cast_columns(datasets, default_split="train"):
+def _maybe_cast_columns(datasets: list, default_split: str = "train") -> None:
     """
     Given list of datasets, try casting datasets to same features.
     Assumes that the datasets are aligned in terms of columns which
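
The `isinstance` fix is more than cosmetic: `DatasetDict or IterableDatasetDict` short-circuits to the first truthy operand, so the old check only ever tested against `DatasetDict`, and an `IterableDatasetDict` fell through to the `else` branch. A minimal sketch with stand-in classes (not the real `datasets` types):

class DatasetDict(dict): ...          # stand-ins: sibling classes, like the real ones
class IterableDatasetDict(dict): ...

d = IterableDatasetDict()

# `X or Y` evaluates to X when X is truthy, so the old form checked
# against DatasetDict only.
print(isinstance(d, (DatasetDict or IterableDatasetDict)))  # False
print(isinstance(d, (DatasetDict, IterableDatasetDict)))    # True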
@@ -95,7 +95,7 @@ def _maybe_cast_columns(datasets, default_split="train"):
         datasets[i] = datasets[i].cast(features)
 
 
-def _validate_mergeable_datasets(datasets, default_split="train"):
+def _validate_mergeable_datasets(datasets: list, default_split: str = "train") -> None:
     """Given list of datasets, validate if all datasets have same type and number of columns."""
     if len(datasets) <= 1:
         return
@@ -122,16 +122,16 @@ def _validate_mergeable_datasets(datasets, default_split="train"):
         )
 
 
-def maybe_align_datasets(datasets):
+def maybe_align_datasets(datasets: list) -> None:
     """
     Given list of datasets
     1. validate if all datasets have same type and number of columns.
     2. try casting dataset columns to same value to ensure mergability
     """
     try:
-        for d in datasets:
+        for i, d in enumerate(datasets):
             if isinstance(d, IterableDataset):
-                d = resolve_iterable_dataset_features(d)
+                datasets[i] = resolve_iterable_dataset_features(d)
 
         _validate_mergeable_datasets(datasets)
         _maybe_cast_columns(datasets)
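
The `enumerate` change fixes a classic rebinding bug: assigning to the loop variable `d` only rebinds a local name, so the resolved iterable dataset was previously discarded and the original entry stayed in `datasets`. A minimal sketch of the difference, using a plain list of strings:

items = ["a", "b", "c"]

# Rebinding the loop variable leaves the list unchanged.
for d in items:
    d = d.upper()
print(items)   # ['a', 'b', 'c']

# Writing back through the index updates the list in place, mirroring
# datasets[i] = resolve_iterable_dataset_features(d) in the new code.
for i, d in enumerate(items):
    items[i] = d.upper()
print(items)   # ['A', 'B', 'C']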
