-
Notifications
You must be signed in to change notification settings - Fork 458
Expand file tree
/
Copy pathcnn_dailymail.py
More file actions
35 lines (26 loc) · 1.23 KB
/
cnn_dailymail.py
File metadata and controls
35 lines (26 loc) · 1.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from copy import deepcopy
from typing import TYPE_CHECKING
from llmcompressor.transformers.finetune.data import TextGenerationDataset
from llmcompressor.typing import Processor
if TYPE_CHECKING:
from llmcompressor.transformers.utils.arg_parser import DatasetArguments
@TextGenerationDataset.register(name="cnn_dailymail")
class CNNDailyMailDataset(TextGenerationDataset):
    """
    Text generation dataset wrapper for CNN/DailyMail

    Each sample is rendered into a single `text` entry that holds the article
    followed by its highlights, laid out according to `SAMPLE_TEMPLATE`.

    :param data_args: configuration settings for dataset loading; a deep copy is
        taken before the dataset name and config name are overridden
    :param split: split from dataset to load, for instance `test` or `train[:5%]`
    :param processor: processor or tokenizer to use on dataset
    """

    # Layout of one rendered sample: the article first, then its highlights
    SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n"

    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
        # Work on a private copy so the overrides below never touch the
        # object the caller passed in
        pinned_args = deepcopy(data_args)
        pinned_args.dataset = "cnn_dailymail"
        pinned_args.dataset_config_name = "3.0.0"

        super().__init__(data_args=pinned_args, split=split, processor=processor)

    def dataset_template(self, sample):
        """
        Render one raw sample into the text used for generation

        :param sample: record providing `article` and `highlights` entries
        :return: dictionary with the formatted sample stored under `text`
        """
        article = sample["article"]
        highlights = sample["highlights"]
        rendered = self.SAMPLE_TEMPLATE.format(article=article, highlights=highlights)
        return {"text": rendered}