-
Notifications
You must be signed in to change notification settings - Fork 457
Expand file tree
/
Copy pathpeoples_speech.py
More file actions
90 lines (71 loc) · 3.67 KB
/
peoples_speech.py
File metadata and controls
90 lines (71 loc) · 3.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict
from datasets.formatting.formatting import LazyRow
from loguru import logger
from llmcompressor.transformers.finetune.data import TextGenerationDataset
from llmcompressor.transformers.finetune.data.base import get_columns
from llmcompressor.typing import DatasetType, Processor
if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments as DataArgs
@TextGenerationDataset.register(name="peoples_speech")
class PeoplesSpeech(TextGenerationDataset):
    """
    ML Commons People's Speech audio dataset

    Unfortunately, due to the specialized nature of audio model preprocessing, some
    model specific code must be defined here. This dataset has been tested with the
    WhisperForConditionalGeneration and Qwen2AudioForConditionalGeneration model
    classes

    :param dataset_args: configuration settings for dataset loading
    :param split: split from dataset to load, for instance `test` or `train[:5%]`
    :param processor: processor or tokenizer to use on dataset
    """

    def __init__(self, dataset_args: "DataArgs", split: str, processor: Processor):
        # Copy before mutating so the caller's args object is left untouched
        dataset_args = deepcopy(dataset_args)
        dataset_args.dataset = "MLCommons/peoples_speech"
        dataset_args.dataset_config_name = "test"
        if not dataset_args.overwrite_cache:
            logger.warning(
                "Because audio processors are more complex, dataset mapping functions "
                "vary with model architecture and their results cannot be cached. "
                "Setting overwrite_cache=True"
            )
            dataset_args.overwrite_cache = True

        # Class name of the processor is used below to dispatch the
        # model-specific preprocessing branches
        self.processor_type = processor.__class__.__name__
        super().__init__(dataset_args=dataset_args, split=split, processor=processor)

    def dataset_template(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """
        Map one raw dataset row to the keyword arguments expected by the
        processor for this model architecture.

        :param example: raw row containing an "audio" dict (with "array" and
            "sampling_rate" keys) and a "text" transcription
        :return: dict of processor inputs; keys differ per processor type
        """
        audio = example["audio"]["array"]
        sampling_rate = example["audio"]["sampling_rate"]

        if self.processor_type == "Qwen2AudioProcessor":
            # Qwen2Audio expects the prompt rendered through its chat template,
            # with the audio passed separately under "audios"
            messages = [
                {"role": "user", "content": [{"audio": None}]},
                {"role": "user", "content": [{"text": "What did the person say?"}]},
            ]
            text = self.processor.apply_chat_template(messages)
            return {"audios": [audio], "sampling_rate": sampling_rate, "text": text}
        else:
            # chat template decoder ids are appended later by self.processor.__call__
            text = " " + example["text"].capitalize()
            return {"audio": audio, "sampling_rate": sampling_rate, "text": text}

    def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType:
        """
        Restrict dataset columns to those accepted by the processor.

        WhisperProcessor's signature cannot be introspected by the generic base
        implementation, so its accepted arguments are listed explicitly here.

        :param dataset: dataset whose columns should be filtered
        :return: dataset containing only processor-consumable columns
        """
        if self.processor_type == "WhisperProcessor":
            tokenizer_args = ["audio", "sampling_rate", "text"]
            column_names = get_columns(dataset)
            return dataset.remove_columns(list(set(column_names) - set(tokenizer_args)))
        else:
            return super().filter_tokenizer_args(dataset)

    def tokenize(self, data: LazyRow) -> Dict[str, Any]:
        """
        Tokenize one mapped row with the processor.

        For Whisper, the processor's "labels" output is renamed to
        "decoder_input_ids" since calibration runs the decoder directly.

        :param data: a single mapped dataset row
        :return: model-ready input dict
        """
        if self.processor_type == "WhisperProcessor":
            inputs = self.processor(
                audio=data["audio"],
                sampling_rate=data["sampling_rate"],
                text=data["text"],
                add_special_tokens=True,
                return_tensors="pt",
            )

            # TODO: inputs["input_features"] is a float dtype, which may conflict with
            # the dtype of the model. Add logic to in data pipeline to move inputs to
            # the matching model device and dtype
            inputs["decoder_input_ids"] = inputs["labels"]
            del inputs["labels"]
            return inputs
        else:
            return super().tokenize(data)