Skip to content

Commit a242be9

Browse files
committed
add annotator
1 parent 2f890b4 commit a242be9

File tree

3 files changed

+291
-1
lines changed

3 files changed

+291
-1
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,12 @@ If everything works correctly, the output should be:
218218
"lf_accuracy": 0.2334609075997813
219219
}
220220
```
221+
222+
223+
## Annotation
224+
225+
In addition to the raw data dump, we also release an optional annotation script that annotates WikiSQL using [Stanford CoreNLP](https://stanfordnlp.github.io/CoreNLP/).
226+
The `annotate.py` script annotates the query, question, and SQL table, and also produces a sequence-to-sequence construction of the input and output for convenience when using Seq2Seq models.
227+
To use `annotate.py`, you must set up the CoreNLP python client using [Stanford Stanza](https://github.com/stanfordnlp/stanza).
228+
Note that the sequence output contains symbols to delineate the boundaries of fields.
229+
In `lib/query.py` you will also find accompanying functions to reconstruct a query given a sequence output in the annotated format.

annotate.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#!/usr/bin/env python
2+
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
3+
import os
4+
import records
5+
import ujson as json
6+
from stanza.nlp.corenlp import CoreNLPClient
7+
from tqdm import tqdm
8+
import copy
9+
from lib.common import count_lines, detokenize
10+
from lib.query import Query
11+
12+
13+
# Shared CoreNLP client; created lazily on the first call to annotate().
client = None


def annotate(sentence, lower=True):
    """Tokenize ``sentence`` with CoreNLP and return its per-token fields.

    The module-level ``client`` is created on first use with the ``ssplit``
    and ``tokenize`` annotators and reused by every later call.

    Args:
        sentence: text handed to ``client.annotate``.
        lower: when true, the entries of ``'words'`` are lowercased
            (``'gloss'`` and ``'after'`` are left untouched).

    Returns:
        A dict with three parallel lists, one entry per token across all
        sentences: ``'gloss'`` (each token's ``originalText``), ``'words'``
        (each token's ``word``) and ``'after'`` (each token's ``after``
        attribute -- presumably the text following the token; confirm
        against the CoreNLP client).
    """
    global client
    if client is None:
        client = CoreNLPClient(default_annotators=['ssplit', 'tokenize'])

    # One (word, originalText, after) triple per token, in document order.
    triples = [
        (token.word, token.originalText, token.after)
        for sent in client.annotate(sentence)
        for token in sent
    ]
    words = [word for word, _, _ in triples]
    if lower:
        words = [word.lower() for word in words]
    return {
        'gloss': [text for _, text, _ in triples],
        'words': words,
        'after': [trailing for _, _, trailing in triples],
    }
33+
34+
35+
def annotate_example(example, table):
    """Build the annotated record for one example and its table.

    Args:
        example: dict with ``'table_id'``, ``'question'`` and ``'sql'``; the
            latter holds ``'sel'``, ``'agg'`` and ``'conds'`` (a list of
            ``[column index, operator index, value]`` triples).
        table: dict whose ``'header'`` is the list of column names.

    Returns:
        A dict with ``'table_id'``, the annotated ``'question'``, the
        annotated ``'table'`` header, ``'query'`` (a deep copy of the SQL
        dict whose condition values are replaced by their annotations), and
        the annotated ``'seq_input'``, ``'seq_output'`` and
        ``'where_output'`` sequences.

    Raises:
        AssertionError: if the ``symend`` marker is missing from either
            annotated output sequence.
    """
    ann = {
        'table_id': example['table_id'],
        'question': annotate(example['question']),
        'table': {'header': [annotate(h) for h in table['header']]},
    }
    header = table['header']

    # Work on a copy so the caller's example is not modified; each condition
    # value is replaced in place by its annotation.
    query = copy.deepcopy(example['sql'])
    ann['query'] = query
    for cond in query['conds']:
        cond[-1] = annotate(str(cond[-1]))

    select_clause = 'SYMSELECT SYMAGG {} SYMCOL {}'.format(
        Query.agg_ops[query['agg']], header[query['sel']])
    cond_clauses = [
        'SYMCOL {} SYMOP {} SYMCOND {}'.format(header[col], Query.cond_ops[op], detokenize(value))
        for col, op, value in query['conds']
    ]
    if cond_clauses:
        where_clause = 'SYMWHERE {} SYMEND'.format(' SYMAND '.join(cond_clauses))
    else:
        where_clause = 'SYMEND'

    seq_input = 'SYMSYMS {syms} SYMAGGOPS {aggops} SYMCONDOPS {condops} SYMTABLE {table} SYMQUESTION {question}'.format(
        syms=' '.join('SYM' + sym for sym in Query.syms),
        table=' '.join('SYMCOL ' + name for name in header),
        question=example['question'],
        aggops=' '.join(Query.agg_ops),
        condops=' '.join(Query.cond_ops),
    )
    ann['seq_input'] = annotate(seq_input)

    # The where clause always ends in SYMEND, so it is never empty.
    ann['seq_output'] = annotate('{} {}'.format(select_clause, where_clause))
    ann['where_output'] = annotate(where_clause)
    assert 'symend' in ann['seq_output']['words']
    assert 'symend' in ann['where_output']['words']
    return ann
65+
66+
67+
def is_valid_example(e):
    """Return whether an annotated example is internally consistent.

    An example is rejected (``False``) when any of these holds:

    * a table header has no tokens;
    * two headers are equal once detokenized and lowercased;
    * a ``seq_output`` word does not occur in ``seq_input``;
    * a word of a condition value does not occur in the question.

    The last two cases also print the offending word and the vocabulary it
    was checked against.

    Args:
        e: annotated example with ``'table'``, ``'seq_input'``,
            ``'seq_output'``, ``'question'`` and ``'query'`` entries.

    Returns:
        ``True`` if all checks pass, otherwise ``False``.
    """
    header_anns = e['table']['header']
    if not all([h['words'] for h in header_anns]):
        return False

    lowered = [detokenize(h).lower() for h in header_anns]
    if len(set(lowered)) != len(lowered):
        return False

    seq_input_words = e['seq_input']['words']
    seq_vocab = set(seq_input_words)
    for word in e['seq_output']['words']:
        if word in seq_vocab:
            continue
        print('query word "{}" is not in input vocabulary.\n{}'.format(word, seq_input_words))
        return False

    question_words = e['question']['words']
    question_vocab = set(question_words)
    for _col, _op, cond in e['query']['conds']:
        for word in cond['words']:
            if word in question_vocab:
                continue
            print('cond word "{}" is not in input vocabulary.\n{}'.format(word, question_words))
            return False
    return True
85+
86+
87+
if __name__ == '__main__':
    # Command-line driver: for each split, load its tables, annotate every
    # example, check it, confirm the query can be rebuilt from the annotated
    # output sequence, and write one JSON object per line.
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--din', default='data', help='data directory')
    parser.add_argument('--dout', default='annotated', help='output directory')
    args = parser.parse_args()

    if not os.path.isdir(args.dout):
        os.makedirs(args.dout)

    for split in ('train', 'dev', 'test'):
        in_prefix = os.path.join(args.din, split)
        fsplit = in_prefix + '.jsonl'
        ftable = in_prefix + '.tables.jsonl'
        fout = os.path.join(args.dout, split) + '.jsonl'

        print('annotating {}'.format(fsplit))
        with open(fsplit) as fs, open(ftable) as ft, open(fout, 'wt') as fo:
            print('loading tables')
            tables = {}
            for raw in tqdm(ft, total=count_lines(ftable)):
                tbl = json.loads(raw)
                tables[tbl['id']] = tbl

            print('loading examples')
            n_written = 0
            for raw in tqdm(fs, total=count_lines(fsplit)):
                example = json.loads(raw)
                annotated = annotate_example(example, tables[example['table_id']])
                if not is_valid_example(annotated):
                    raise Exception(str(annotated))

                # Round-trip check: the query rebuilt from the annotated
                # output sequence must match the gold query.
                gold = Query.from_tokenized_dict(annotated['query'])
                reconstruct = Query.from_sequence(annotated['seq_output'], annotated['table'], lowercase=True)
                if gold.lower() != reconstruct.lower():
                    raise Exception('Expected:\n{}\nGot:\n{}'.format(gold, reconstruct))

                fo.write(json.dumps(annotated) + '\n')
                n_written += 1
            print('wrote {} examples'.format(n_written))

lib/query.py

Lines changed: 160 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
from copy import deepcopy
44
import re
55

6+
67
re_whitespace = re.compile(r'\s+', flags=re.UNICODE)
78

89

910
class Query:
1011

1112
agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
1213
cond_ops = ['=', '>', '<', 'OP']
13-
syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']
14+
syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS', 'END']
1415

1516
def __init__(self, sel_index, agg_index, conditions=tuple()):
1617
self.sel_index = sel_index
@@ -68,3 +69,161 @@ def from_generated_dict(cls, d):
6869
end = len(val['words'])
6970
conds.append([col, op, detokenize(val)])
7071
return cls(d['sel'], d['agg'], conds)
72+
73+
@classmethod
def from_sequence(cls, sequence, table, lowercase=True):
    """Reconstruct a query from a full annotated output sequence.

    The expected token layout is::

        symselect symagg [<agg op>] symcol <column tokens>
            [symwhere symcol <column> symop <op> symcond <value>
                [symand symcol ...]] [symend]

    The select/aggregation prefix is parsed here; the WHERE part is parsed
    by ``from_partial_sequence`` so both constructors share one
    implementation of the condition grammar.

    Args:
        sequence: dict with parallel ``'words'``, ``'gloss'`` and
            ``'after'`` lists. It is not modified.
        table: dict whose ``'header'`` is a list of annotated headers
            (each is passed to ``detokenize``).
        lowercase: lowercase the headers and every token field before
            matching. Marker tokens are always compared against their
            lowercase spelling, so with ``lowercase=False`` they must
            already be lowercase.

    Returns:
        ``cls(sel_index, agg_index, conditions)``.

    Raises:
        Exception: with a descriptive message when a marker is missing,
            the aggregation operator is unknown, or a column cannot be
            matched against the table header. A sequence that stops early
            reports the missing marker rather than a bare ``IndexError``.
    """
    truncated = deepcopy(sequence)
    # Drop the end marker and everything after it.
    if 'symend' in truncated['words']:
        end = truncated['words'].index('symend')
        for key, values in truncated.items():
            truncated[key] = values[:end]
    terms = [
        {'gloss': g, 'word': w, 'after': a}
        for g, w, a in zip(truncated['gloss'], truncated['words'], truncated['after'])
    ]
    headers = [detokenize(h) for h in table['header']]

    if lowercase:
        headers = [h.lower() for h in headers]
        for term in terms:
            for key, value in term.items():
                term[key] = value.lower()
    # Columns are matched with all whitespace removed on both sides.
    headers_no_whitespace = [re.sub(re_whitespace, '', h) for h in headers]

    def next_word():
        # None never equals a marker, so running out of tokens surfaces as
        # the matching "Missing ..." error instead of an IndexError.
        return terms.pop(0)['word'] if terms else None

    # get select
    if next_word() != 'symselect':
        raise Exception('Missing symselect operator')

    # get aggregation
    if next_word() != 'symagg':
        raise Exception('Missing symagg operator')
    agg_op = next_word()
    if agg_op == 'symcol':
        # No aggregation operator was emitted between symagg and symcol.
        agg_op = ''
    elif next_word() != 'symcol':
        raise Exception('Missing aggregation column')
    try:
        agg_op = cls.agg_ops.index(agg_op.upper())
    except (AttributeError, ValueError):
        raise Exception('Invalid agg op {}'.format(agg_op))

    # The selected column is everything up to the WHERE marker (or the end).
    where_index = next(
        (i for i, term in enumerate(terms) if term['word'] == 'symwhere'),
        len(terms),
    )
    col_terms = terms[:where_index]
    flat = {
        'words': [term['word'] for term in col_terms],
        'after': [term['after'] for term in col_terms],
        'gloss': [term['gloss'] for term in col_terms],
    }
    try:
        agg_col = headers_no_whitespace.index(re.sub(re_whitespace, '', detokenize(flat)))
    except Exception:
        raise Exception('Cannot find aggregation column {}'.format(flat['words']))

    # The validated prefix cannot contain symwhere, so the partial parser
    # locates the same WHERE clause when given the original sequence.
    return cls.from_partial_sequence(agg_col, agg_op, sequence, table, lowercase=lowercase)
163+
164+
@classmethod
def from_partial_sequence(cls, agg_col, agg_op, sequence, table, lowercase=True):
    """Build a query from known select/aggregation indices plus a WHERE sequence.

    Only the part of ``sequence`` after the first ``symwhere`` token is
    parsed; it must look like::

        symcol <column tokens> symop <op> symcond <value tokens>
            [symand symcol ...]

    Anything from ``symend`` onwards is discarded first. If there is no
    ``symwhere`` token the query has no conditions.

    Args:
        agg_col: index of the selected column, passed through to ``cls``.
        agg_op: index of the aggregation operator, passed through to ``cls``.
        sequence: dict with parallel ``'words'``, ``'gloss'`` and
            ``'after'`` lists. It is not modified.
        table: dict whose ``'header'`` is a list of annotated headers
            (each is passed to ``detokenize``).
        lowercase: lowercase the headers and every token field before
            matching. Marker tokens are always compared against their
            lowercase spelling.

    Returns:
        ``cls(agg_col, agg_op, conditions)`` where each condition is
        ``[column index, operator index, detokenized value]``.

    Raises:
        Exception: with a descriptive message when a condition is
            malformed (missing ``symcol``/``symop``/``symcond``, missing or
            unknown operator, or unknown column).
    """
    sequence = deepcopy(sequence)
    # Drop the end marker and everything after it.
    if 'symend' in sequence['words']:
        end = sequence['words'].index('symend')
        for key, values in sequence.items():
            sequence[key] = values[:end]
    terms = [
        {'gloss': g, 'word': w, 'after': a}
        for g, w, a in zip(sequence['gloss'], sequence['words'], sequence['after'])
    ]
    headers = [detokenize(h) for h in table['header']]

    if lowercase:
        headers = [h.lower() for h in headers]
        for term in terms:
            for key, value in term.items():
                term[key] = value.lower()
    # Columns are matched with all whitespace removed on both sides.
    headers_no_whitespace = [re.sub(re_whitespace, '', h) for h in headers]

    def find_column(name):
        # Raises ValueError when no header matches.
        return headers_no_whitespace.index(re.sub(re_whitespace, '', name))

    def flatten(tokens):
        # Turn a list of token dicts back into the parallel-list layout
        # that detokenize() is given elsewhere in this class.
        return {
            'words': [t['word'] for t in tokens],
            'after': [t['after'] for t in tokens],
            'gloss': [t['gloss'] for t in tokens],
        }

    where_index = next(
        (i for i, term in enumerate(terms) if term['word'] == 'symwhere'),
        len(terms),
    )
    where_terms = terms[where_index + 1:]

    # get conditions
    conditions = []
    while where_terms:
        first = where_terms.pop(0)
        flat = flatten(where_terms)
        if first['word'] != 'symcol':
            raise Exception('Missing conditional column {}'.format(flat['words']))
        try:
            op_index = flat['words'].index('symop')
        except ValueError:
            raise Exception('Missing conditional operator {}'.format(flat['words']))
        col_tokens = flatten(where_terms[:op_index])
        # symop must be followed by the operator token itself; without this
        # check a sequence ending in symop raised a bare IndexError.
        if op_index + 1 >= len(where_terms):
            raise Exception('Missing conditional operator {}'.format(flat['words']))
        cond_op = where_terms[op_index + 1]['word']
        try:
            cond_op = cls.cond_ops.index(cond_op.upper())
        except (AttributeError, ValueError):
            raise Exception('Invalid cond op {}'.format(cond_op))
        try:
            cond_col = find_column(detokenize(col_tokens))
        except Exception:
            raise Exception('Cannot find conditional column {}'.format(col_tokens['words']))
        try:
            val_index = flat['words'].index('symcond')
        except ValueError:
            raise Exception('Cannot find conditional value {}'.format(flat['words']))

        # The value runs up to the next symand, or to the end of the clause.
        where_terms = where_terms[val_index + 1:]
        flat = flatten(where_terms)
        if 'symand' in flat['words']:
            val_end_index = flat['words'].index('symand')
        else:
            val_end_index = len(where_terms)
        cond_val = detokenize(flatten(where_terms[:val_end_index]))
        conditions.append([cond_col, cond_op, cond_val])
        where_terms = where_terms[val_end_index + 1:]
    return cls(agg_col, agg_op, conditions)

0 commit comments

Comments
 (0)