-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
245 lines (219 loc) · 9.12 KB
/
evaluate.py
File metadata and controls
245 lines (219 loc) · 9.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import json
import os
import random
import argparse
from openai import AsyncOpenAI, OpenAI
from openai.types.chat import ChatCompletion
import asyncio
import uvloop
### python process_dataset.py --type KG --dataset_dir ./dataset/KG_random_42_1k.json --api_key 0 --api_url http://xxxxx --model llama3 --async_mode
# ---- command-line arguments ------------------------------------------------
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--type', default="KG", type=str)  # task type: KG / Table / KG+Text / Table+Text
parser.add_argument('--dataset_dir', default="dataset/KG_random_42_1k.json", type=str)  # dataset JSON to evaluate
parser.add_argument('--api_key', default="0", type=str)  # key for the OpenAI-compatible endpoint
parser.add_argument('--api_url', default="http://0.0.0.0:8000/v1", type=str)  # OpenAI-compatible base URL
parser.add_argument('--model', default="llama3", type=str)  # model name passed to the API
parser.add_argument('--async_mode', action='store_true', help='Enable async mode')
parser.add_argument('--evaluate_only', action='store_true', help='Enable evaluation only')
args = parser.parse_args()
os.makedirs('./output', exist_ok=True)
output_dir = f'./output/'
# Build a synchronous or asynchronous client depending on --async_mode.
# NOTE(review): max_retries differs between the two paths (10 sync vs 5 async) — confirm intended.
if args.api_key and args.api_url:
    if not args.async_mode:
        key=args.api_key
        url=args.api_url
        client = OpenAI(
            api_key=key, base_url=url, max_retries=10
        )
    else:
        key=args.api_key
        url=args.api_url
        client = AsyncOpenAI(
            api_key=key, base_url=url, max_retries=5
        )
else:
    raise ValueError("Please input the api_key and api_url")
id2ans = dict()     # question id -> gold answer list (filled while building prompts)
id2numpos = dict()  # question id -> number of positive evidence pieces, capped at 5
############################# prompt #################################
with open(args.dataset_dir, encoding='utf-8') as f:
    samples = json.load(f)
with open(f'{args.type}.json', encoding='utf-8') as f:
    original = json.load(f)
question_id = []
prompts = []

# Which fields of the ORIGINAL (unperturbed) items hold positive/gold evidence,
# per task type. Replaces four copy-pasted if-branches with one data-driven loop.
_POS_FIELDS = {
    "KG": ('positive_triples',),
    "Table": ('positive_rows',),
    "KG+Text": ('positive_triples', 'positive_texts'),
    "Table+Text": ('positive_rows', 'positive_texts'),
}
for item in original:
    fields = _POS_FIELDS.get(args.type)
    if fields is None:
        # Unknown --type: the original branch chain also did nothing here.
        continue
    # Cap at 5 so downstream reporting buckets by evidence count (5 == "5 or more").
    id2numpos[item['id']] = min(sum(len(item[f]) for f in fields), 5)
# Build one prompt per sample; record question ids and gold answers.
# NOTE(review): there is no newline between the joined evidence and "Task
# Description", and "\{" leaves a literal backslash in the prompt text — the
# evaluator later strips backslashes from model output, so this appears
# intentional; confirm before "fixing".
for sample in samples:
    if args.type=="KG":
        id = sample['id']
        question_id.append(id)
        question = sample['question']
        answer = sample['answer']
        data = sample['data']  # list of triple strings
        id2ans[id] = answer
        prompt = "### Triples:\n"+'\n'.join(data)+f"Task Description: Based on the triples provided above, please answer the following questions.\n ### Question: {question}\n"+"Return the final result as JSON in the format \{\"answer\": <YOUR ANSWER STRING LIST>\} in the last line."
        prompts.append(prompt)
    if args.type=="Table":
        id = sample['id']
        question_id.append(id)
        question = sample['question']
        answer = sample['answer']
        data = sample['data']  # list of table-row strings
        id2ans[id] = answer
        prompt = "### Table:\n"+'\n'.join(data)+f"Task Description: Please look at the table, and then answer the following questions.\n ### Question: {question}\n"+"Return the final result as JSON in the format \{\"answer\": <YOUR ANSWER STRING LIST>\} in the last line."
        prompts.append(prompt)
    if args.type=="KG+Text":
        id = sample['id']
        question_id.append(id)
        question = sample['question']
        answer = sample['answer']
        KG_data = sample['KG_data']
        Text_data = sample['Text_data']
        id2ans[id] = answer
        prompt = "### Triples:\n"+'\n'.join(KG_data)+"\n### Texts:\n"+'\n'.join(Text_data)+f"Task Description: Based on the triples and texts provided above, please answer the specific product for following questions.\n ### Question: {question}\n"+"Return the final result as JSON in the format \{\"answer\": <YOUR ANSWER STRING LIST>\} in the last line."
        prompts.append(prompt)
    if args.type=="Table+Text":
        id = sample['id']
        question_id.append(id)
        question = sample['question']
        answer = sample['answer']
        Table_data = sample['Table_data']
        Text_data = sample['Text_data']
        id2ans[id] = answer
        prompt = "### Table:\n"+'\n'.join(Table_data)+"\n### Texts:\n"+'\n'.join(Text_data)+f"Task Description: Based on the table and texts provided above, please answer the specific product for following questions.\n ### Question: {question}\n"+"Return the final result as JSON in the format \{\"answer\": <YOUR ANSWER STRING LIST>\} in the last line."
        prompts.append(prompt)
# question_id = question_id[168:]
# prompts = prompts[168:]
print("############################# Finish prompts #################################")
############################## async API calls ##############################
async def translate(id, prompt, file):
    """Send one prompt to the chat API and append the exchange to *file* as JSONL.

    Writes a record ``{"id", "input", "output"}`` and returns the raw API response.
    Relies on the module-level ``client`` and ``args``.
    """
    messages = [{"role": "user", "content": prompt}]
    response = await client.chat.completions.create(
        model=args.model,
        messages=messages,
    )
    print(id)  # progress marker
    record = {"id": id, "input": prompt, "output": response.choices[0].message.content}
    file.write(json.dumps(record, ensure_ascii=False) + '\n')  # one JSON object per line
    return response
async def test():
    """Send every prompt to the API, batch_size concurrent requests at a time.

    Results are appended (JSONL) to one output file named after the model and
    dataset. Relies on module-level ``prompts``, ``question_id``, ``args``.
    """
    os.makedirs(output_dir, exist_ok=True)
    out_name = f"{output_dir}/{args.model}_{args.dataset_dir.split('/')[-1].split('.json')[0]}.jsonl"
    batch_size = 10
    # 'with' guarantees the file is closed even if a request raises.
    with open(out_name, "a+", encoding='utf-8') as file:
        # range(0, len, batch_size) also covers the final partial batch, which
        # the original floor-division loop (len // batch_size) silently dropped.
        for start in range(0, len(prompts), batch_size):
            tasks = [
                asyncio.create_task(
                    translate(question_id[start + offset], prompt, file)
                )
                for offset, prompt in enumerate(prompts[start:start + batch_size])
            ]
            # Await each batch before launching the next so at most batch_size
            # requests are in flight (the original gathered everything at once,
            # so batching had no throttling effect).
            await asyncio.gather(*tasks)
# ---- dispatch: async vs direct synchronous API calls ------------------------
if args.async_mode and not args.evaluate_only:
    uvloop.install()
    asyncio.run(test())
############################# direct (synchronous) API calls ##############################
elif not args.evaluate_only:
    os.makedirs(output_dir, exist_ok=True)
    out_name = f"{output_dir}/{args.model}_{args.dataset_dir.split('/')[-1].split('.json')[0]}.jsonl"
    # Open once with a context manager; the original reopened the file on every
    # iteration and never closed it, leaking file descriptors.
    with open(out_name, "a+", encoding='utf-8') as file:
        for idx, prompt in enumerate(prompts):
            print(idx)  # progress marker
            try:
                completion = client.chat.completions.create(
                    model=args.model,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=4096,
                )
            except Exception as e:
                # Best-effort: skip a failed request, but report it instead of
                # silently dropping the sample as the original did.
                print(f"request {idx} failed: {e}")
                continue
            record = {"id": question_id[idx], "input": prompt,
                      "output": completion.choices[0].message.content}
            file.write(json.dumps(record, ensure_ascii=False) + '\n')
############################# evaluate ##############################
### information integration
def f1_score(list1, list2):
    """Return the set-based F1 between predicted items *list1* and gold items *list2*.

    Duplicates are ignored (both sides are compared as sets). A scalar
    prediction (e.g. a bare int the model returned instead of a list) is
    wrapped as a single-element string list, matching how gold answers are
    stored. Returns 0 when either side is empty or there is no overlap.
    """
    # Generalizes the original int-only special case: wrap ANY non-sequence
    # scalar (int, float, str) instead of letting e.g. a str explode into
    # a set of characters. Uses isinstance, not type(x) == int.
    if not isinstance(list1, (list, set, tuple)):
        list1 = [f"{list1}"]
    set1, set2 = set(list1), set(list2)
    intersection = set1 & set2
    precision = len(intersection) / len(set1) if set1 else 0
    recall = len(intersection) / len(set2) if set2 else 0
    if precision + recall == 0:
        return 0
    return 2 * (precision * recall) / (precision + recall)
result_path = f"{output_dir}/{args.model}_{args.dataset_dir.split('/')[-1].split('.json')[0]}.jsonl"
with open(result_path, encoding='utf-8') as f:
    results = [json.loads(line) for line in f]
### compute f1 score
f1 = 0
syntax_error = 0   # outputs with no parseable {"answer": ...} JSON
api_failure = 0    # outputs that parsed to an empty answer list
numpos2ans = dict()  # num positive evidence pieces -> list of per-sample F1 scores
for result in results:
    qid = result['id']
    if result['output'] is None:
        continue
    # Keep only the LAST {...} span of the model output (the prompt asks for
    # the JSON in the last line) and strip the literal backslashes that the
    # prompt's "\{" escaping induces in model replies.
    output = ('{' + result['output'].split('{')[-1].split('}')[0] + '}').replace("\\", "")
    print(output)
    if "answer" not in output:
        syntax_error += 1
        continue
    try:
        output = json.loads(output)['answer']
    except (json.JSONDecodeError, KeyError, TypeError):
        # Narrowed from a bare except; these are the failures this step produces.
        syntax_error += 1
        continue
    # Normalize a scalar answer (int/float/str) to a single-element string list.
    # (The original also had a dead `elif type(output)!=list` branch here —
    # unreachable after this wrap — which has been removed.)
    if not isinstance(output, list):
        output = [f"{output}"]
    if not output:
        api_failure += 1
        continue
    score = f1_score(output, id2ans[qid])  # compute once (original called it 3x)
    f1 += score
    numpos2ans.setdefault(id2numpos[qid], []).append(score)
# NOTE: deliberately divides by ALL results, so skipped/failed samples count
# as 0 — this matches the original scoring.
f1 = f1/len(results)
print(f"num of sample is {len(results)}")
print(f"num of syntax error is {syntax_error}")
print(f"num of api failure is {api_failure}")
print(f"f1 score is {f1}")
print("**** information integration ****")
for key in numpos2ans:
    print(f"num of positive samples is {key} f1 score is {sum(numpos2ans[key])/len(numpos2ans[key])}")