-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_processing.py
More file actions
36 lines (26 loc) · 982 Bytes
/
data_processing.py
File metadata and controls
36 lines (26 loc) · 982 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from functools import partial
from transformers import AutoTokenizer
import re
from prompt_formatting import create_prompt_formats
def clean_string(s):
    """Return ``s`` with line feeds and punctuation stripped out.

    Line feed characters are deleted outright rather than replaced by a
    space, so text on either side of a line break ends up joined together.
    Afterwards, every character that is neither a regex word character nor
    whitespace is removed. Other whitespace, such as spaces, tabs and
    carriage returns, is kept.
    """
    joined = ''.join(s.split('\n'))
    return re.sub(r'[^\w\s]', '', joined)
def preprocess_batch(batch, tokenizer, max_length):
    """Tokenize the ``"text"`` entry of a batch.

    Args:
        batch: mapping whose ``"text"`` value is passed as the first
            positional argument to ``tokenizer``. In a batched ``map``,
            this is presumably a list of strings; TODO confirm.
        tokenizer: callable tokenizer, invoked with ``max_length`` and
            ``truncation=True``.
        max_length: maximum sequence length forwarded to the tokenizer.

    Returns:
        Whatever the tokenizer call returns, unchanged.
    """
    texts = batch["text"]
    options = {"max_length": max_length, "truncation": True}
    return tokenizer(texts, **options)
def preprocess_dataset(
    tokenizer: AutoTokenizer,
    max_length: int,
    seed,
    dataset,
    *,
    remove_columns=("Instruction", "text", "category", "Text"),
):
    """Format, tokenize, length-filter and shuffle a dataset.

    Pipeline:
      1. Map ``create_prompt_formats`` over every example, one at a time.
      2. Tokenize in batches via ``preprocess_batch``, which truncates at
         ``max_length``, and drop ``remove_columns`` from the result.
      3. Keep only examples whose ``input_ids`` are shorter than
         ``max_length``.
      4. Shuffle with ``seed``.

    Args:
        tokenizer: tokenizer passed through to ``preprocess_batch``.
        max_length: truncation length for tokenization and the exclusive
            upper bound used by the length filter.
        seed: seed forwarded to ``dataset.shuffle``.
        dataset: object providing ``map``, ``filter`` and ``shuffle``. This
            is presumably a Hugging Face ``datasets.Dataset``; TODO confirm
            against callers. (It was previously annotated as ``str``, which
            cannot be right.)
        remove_columns: column name or names dropped by the tokenizing
            ``map``. The default is the list this function always used.
            ``None`` keeps every column.

    Returns:
        The processed dataset, as returned by ``dataset.shuffle``.
    """
    print("Preprocessing dataset...")

    # Step 1: presumably produces the "text" field that preprocess_batch reads.
    dataset = dataset.map(create_prompt_formats)

    # Accept any iterable of names, but hand datasets a plain list.
    # A single str is passed through as-is.
    if remove_columns is not None and not isinstance(remove_columns, str):
        remove_columns = list(remove_columns)

    tokenize = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
    dataset = dataset.map(
        tokenize,
        batched=True,
        remove_columns=remove_columns,
    )

    # Because tokenization uses truncation=True, lengths should already be
    # capped at max_length. The strict "<" therefore drops every example that
    # reached the cap. That includes truncated ones and ones that were
    # exactly max_length tokens long. Presumably the intent is to discard
    # truncated samples; confirm before changing to "<=".
    dataset = dataset.filter(lambda sample: len(sample["input_ids"]) < max_length)

    dataset = dataset.shuffle(seed=seed)
    return dataset