-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathgeneration.py
More file actions
115 lines (85 loc) · 3.12 KB
/
generation.py
File metadata and controls
115 lines (85 loc) · 3.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import time
import torch
import transformers
def generate_stream(model, tokenizer, prompt: str, device, max_new_tokens: int, context_len: int, echo: bool = False, stream_interval=2):
    """Greedily generate tokens from `prompt`, yielding partial decodes.

    Args:
        model: causal LM exposing `config.eos_token_id`, `device`, and a call
            returning an object with `.logits` and `.past_key_values`.
        tokenizer: provides `__call__` (returning `.input_ids`) and `decode`.
        prompt: input text to condition generation on.
        device: ignored — the model's own device is used instead; kept only
            for backward compatibility (NOTE(review): consider honoring it).
        max_new_tokens: upper bound on generated tokens; clamped so that
            prompt length + new tokens <= context_len.
        context_len: maximum total sequence length (prompt + generation).
        echo: if True, decoded text includes the prompt tokens.
        stream_interval: emit a partial decode every `stream_interval` steps
            (also on the last step and on hitting a stop token).

    Yields:
        dict with key 'text' holding the text decoded so far. Always yields
        at least once; when the prompt already fills the context window an
        empty-text dict is yielded (previously this path raised NameError).
    """
    stop_token_ids = [model.config.eos_token_id]
    device = model.device  # overrides the caller-supplied device (see docstring)
    inputs = tokenizer(prompt)
    next_tokens = torch.tensor(inputs.input_ids, dtype=torch.int64, device=device).unsqueeze(0)
    past_kvs = None
    output_ids = list(inputs.input_ids)
    input_echo_len = len(output_ids)
    # Clamp so the total sequence fits inside the context window.
    max_new_tokens = min(context_len - input_echo_len, max_new_tokens)
    yielded = False
    for i in range(max_new_tokens):
        with torch.no_grad():
            results = model(next_tokens, past_key_values=past_kvs, use_cache=True)
        past_kvs = results.past_key_values
        # Greedy search: pick the highest-scoring next token; only that token
        # is fed back because the KV cache carries the earlier context.
        next_tokens = torch.argmax(results.logits[:, -1, :], dim=1, keepdim=True)
        token = next_tokens[0].item()
        output_ids.append(token)
        stopped = token in stop_token_ids
        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            tmp_output_ids = output_ids if echo else output_ids[input_echo_len:]
            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            yield {'text': output}
            yielded = True
        if stopped:
            break
    if not yielded:
        # No room for any new token: previously `output` was undefined here
        # and the trailing yield raised NameError. Emit an empty result so
        # consumers still receive exactly one chunk.
        yield {'text': ''}
def generate(model, tokenizer, prompt: str, max_new_tokens: int, context_len: int, echo: bool = False):
    """Greedily generate a completion for `prompt` and return the full text.

    Args:
        model: causal LM exposing `config.eos_token_id`, `device`, and a call
            returning an object with `.logits` and `.past_key_values`.
        tokenizer: provides `__call__` (returning `.input_ids`) and `decode`.
        prompt: input text to condition generation on.
        max_new_tokens: upper bound on generated tokens; clamped so that
            prompt length + new tokens <= context_len.
        context_len: maximum total sequence length (prompt + generation).
        echo: if True, decoded text includes the prompt tokens.

    Returns:
        dict with key 'text' holding the decoded output (empty string when
        no new tokens could be generated).
    """
    stop_token_ids = [model.config.eos_token_id]
    device = model.device
    inputs = tokenizer(prompt)
    next_tokens = torch.tensor(inputs.input_ids, dtype=torch.int64, device=device).unsqueeze(0)
    past_kvs = None
    output_ids = list(inputs.input_ids)
    input_echo_len = len(output_ids)
    # Clamp so the total sequence fits inside the context window.
    max_new_tokens = min(context_len - input_echo_len, max_new_tokens)
    for _ in range(max_new_tokens):
        with torch.no_grad():
            results = model(next_tokens, past_key_values=past_kvs, use_cache=True)
        past_kvs = results.past_key_values
        # Greedy search: pick the highest-scoring next token; only that token
        # is fed back because the KV cache carries the earlier context.
        next_tokens = torch.argmax(results.logits[:, -1, :], dim=1, keepdim=True)
        token = next_tokens[0].item()
        output_ids.append(token)
        if token in stop_token_ids:
            break
    tmp_output_ids = output_ids if echo else output_ids[input_echo_len:]
    output = tokenizer.decode(
        tmp_output_ids,
        skip_special_tokens=True,
        spaces_between_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return {'text': output}