-
Notifications
You must be signed in to change notification settings - Fork 458
Expand file tree
/
Copy pathwikitext.py
More file actions
30 lines (23 loc) · 1007 Bytes
/
wikitext.py
File metadata and controls
30 lines (23 loc) · 1007 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from copy import deepcopy
from typing import TYPE_CHECKING
from llmcompressor.transformers.finetune.data import TextGenerationDataset
from llmcompressor.typing import Processor
if TYPE_CHECKING:
from llmcompressor.transformers.utils.arg_parser import DatasetArguments
@TextGenerationDataset.register(name="wikitext")
class WikiTextDataset(TextGenerationDataset):
    """
    Text generation dataset bound to the ``Salesforce/wikitext`` dataset,
    registered under the name ``wikitext``

    :param data_args: configuration settings for dataset loading
    :param split: split from dataset to load, for instance `test` or `train[:5%]`
    :param processor: processor or tokenizer to use on dataset
    """

    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
        """
        Pin the dataset name and text column, then defer to the parent class

        The incoming ``data_args`` is deep-copied first, so the overrides
        below are applied to a private copy rather than the caller's object.
        """
        wikitext_args = deepcopy(data_args)

        # Overrides applied on the copy only
        wikitext_args.dataset = "Salesforce/wikitext"
        wikitext_args.text_column = "text"

        super().__init__(data_args=wikitext_args, split=split, processor=processor)