forked from zhaijianyang/MQL4GRec
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcollator.py
More file actions
executable file
·112 lines (82 loc) · 3.43 KB
/
collator.py
File metadata and controls
executable file
·112 lines (82 loc) · 3.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import copy
import argparse
from dataclasses import dataclass
import transformers
import math
from torch.utils.data import Sampler
import torch.distributed as dist
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig, T5Tokenizer, T5Config, T5ForConditionalGeneration
class Collator(object):
    """Training collator.

    Tokenizes each example's input/label text pair into padded tensors and
    replaces label padding positions with -100 so they are ignored by the
    cross-entropy loss.
    """

    def __init__(self, args, tokenizer):
        self.args = args
        self.only_train_response = args.only_train_response
        self.tokenizer = tokenizer
        # Some tokenizers ship without a pad token; fall back to id 0.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = 0

    def __call__(self, batch):
        # Shared tokenization settings: pad to the longest sequence in the
        # batch, truncate at the tokenizer's max length, return torch tensors.
        encode_kwargs = dict(
            return_tensors="pt",
            padding="longest",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_attention_mask=True,
        )
        source_texts = [example["input_ids"] for example in batch]
        target_texts = [example["labels"] for example in batch]
        inputs = self.tokenizer(source_texts, **encode_kwargs)
        targets = self.tokenizer(target_texts, **encode_kwargs)
        label_ids = targets["input_ids"]
        # Padding positions must not contribute to the loss.
        label_ids[label_ids == self.tokenizer.pad_token_id] = -100
        inputs["labels"] = label_ids
        return inputs
class TestCollator(object):
    """Inference collator.

    Tokenizes input prompts (each with an optional `prefix_token` string
    appended) and passes the raw label strings through untouched so the
    evaluation loop can compare generated text against them.
    """

    def __init__(self, args, tokenizer):
        self.args = args
        self.tokenizer = tokenizer
        # Optional text appended to every prompt; absent attribute -> "".
        self.prefix_token = vars(args).get('prefix_token', '')
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = 0
        # Decoder-only Llama models generate after the prompt, so pad on the
        # left to keep prompts right-aligned for batched inference.
        if isinstance(self.tokenizer, LlamaTokenizer):
            self.tokenizer.padding_side = "left"

    def __call__(self, batch):
        prompts = [example["input_ids"] + self.prefix_token for example in batch]
        targets = [example["labels"] for example in batch]
        encoded = self.tokenizer(
            text=prompts,
            return_tensors="pt",
            padding="longest",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_attention_mask=True,
        )
        return (encoded, targets)
class TestCollatorSave(object):
    """Inference collator that also forwards per-example identifiers.

    Same tokenization as TestCollator, but additionally passes through each
    example's "label" field so downstream code can save results keyed by it.
    """

    def __init__(self, args, tokenizer):
        self.args = args
        self.tokenizer = tokenizer
        # Optional text appended to every prompt; absent attribute -> "".
        self.prefix_token = vars(args).get('prefix_token', '')
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = 0
        # Decoder-only Llama models generate after the prompt, so pad on the
        # left to keep prompts right-aligned for batched inference.
        if isinstance(self.tokenizer, LlamaTokenizer):
            self.tokenizer.padding_side = "left"

    def __call__(self, batch):
        prompts = [example["input_ids"] + self.prefix_token for example in batch]
        targets = [example["labels"] for example in batch]
        # NOTE(review): "label" (singular) here vs "labels" above — assumed to
        # hold user identifiers; confirm against the dataset definition.
        users = [example["label"] for example in batch]
        encoded = self.tokenizer(
            text=prompts,
            return_tensors="pt",
            padding="longest",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_attention_mask=True,
        )
        return (encoded, targets, users)