
Commit 2f890b4

committed: initial
0 parents  commit 2f890b4

15 files changed, +589 -0 lines changed

.dockerignore

Lines changed: 1 addition & 0 deletions
data/

.gitattributes

Lines changed: 1 addition & 0 deletions
*.bz2 filter=lfs diff=lfs merge=lfs -text

.gitignore

Lines changed: 3 additions & 0 deletions
*.py[cod]
__pycache__/
data/

LICENSE

Lines changed: 29 additions & 0 deletions
BSD 3-Clause License

Copyright (c) 2017, Salesforce Research
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

README.md

Lines changed: 220 additions & 0 deletions
# WikiSQL

WikiSQL is a large crowd-sourced dataset for developing natural language interfaces for relational databases.

## Citation

If you use WikiSQL, please cite the following work:

> Victor Zhong, Caiming Xiong, and Richard Socher. 2017. Seq2SQL: Generating Structured Queries from Natural Language using Reinforcement Learning.

## Installation

Both the evaluation script and the dataset are stored in the repo.
To track the data file, we use [Git LFS](https://git-lfs.github.com/).
The installation steps are as follows:

```bash
git clone https://github.com/MetaMind/WikiSQL
cd WikiSQL
pip install -r requirements.txt
tar xvjf data.tar.bz2
```

This will unpack the data files into a directory called `data`.

## Content and format

Inside the `data` folder you will find the files in `jsonl` and `db` format.
The former can be read line by line, where each line is a serialized JSON object.
The latter is a SQLite3 database.
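For example, each line of a `*.jsonl` file can be parsed with the standard `json` module. This is only an illustrative sketch; the file name below is a placeholder for whichever split you unpacked.

```python
import json

# Placeholder file name; any of the unpacked *.jsonl files is read the same way.
with open('data/dev.jsonl') as f:
    for line in f:
        example = json.loads(line)
        print(example['question'], example['table_id'])
```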

### Question, query and table ID

These are contained in the `*.jsonl` files. A line looks like the following:

```json
{
   "phase":1,
   "question":"who is the manufacturer for the order year 1998?",
   "sql":{
      "conds":[
         [
            0,
            0,
            "1998"
         ]
      ],
      "sel":1,
      "agg":0
   },
   "table_id":"1-10007452-3"
}
```

The fields represent the following (a small decoding sketch follows this list):

- `phase`: the phase in which the dataset was collected. We collected WikiSQL in two phases.
- `question`: the natural language question written by the worker.
- `table_id`: the ID of the table to which this question is addressed.
- `sql`: the SQL query corresponding to the question. This has the following subfields:
  - `sel`: the numerical index of the column that is being selected. You can find the actual column from the table.
  - `agg`: the numerical index of the aggregation operator that is being used. You can find the actual operator from `Query.agg_ops` in `lib/query.py`.
  - `conds`: a list of triplets `(column_index, operator_index, condition)` where:
    - `column_index`: the numerical index of the condition column that is being used. You can find the actual column from the table.
    - `operator_index`: the numerical index of the condition operator that is being used. You can find the actual operator from `Query.cond_ops` in `lib/query.py`.
    - `condition`: the comparison value for the condition, in either `string` or `float` type.
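As a sketch of how these indices map back to a readable query, the hypothetical helper below combines the table header with the operator lists exposed as `Query.agg_ops` and `Query.cond_ops` in `lib/query.py` (the helper itself is not part of the repo):

```python
from lib.query import Query  # Query.agg_ops and Query.cond_ops index the operators


def to_readable(sql, header):
    """Render the sel/agg/conds indices of one example as a query string."""
    column = header[sql['sel']]
    agg = Query.agg_ops[sql['agg']]
    select = '{}({})'.format(agg, column) if agg else column
    conds = ' AND '.join(
        '{} {} {!r}'.format(header[col], Query.cond_ops[op], val)
        for col, op, val in sql['conds'])
    return 'SELECT {} FROM table{}'.format(
        select, ' WHERE ' + conds if conds else '')

# e.g. to_readable(example['sql'], table['header']) for the example above.
```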

### Tables

The tables are contained in the `*.tables.jsonl` files. A line looks like the following:

```json
{
   "id":"1-1000181-1",
   "header":[
      "State/territory",
      "Text/background colour",
      "Format",
      "Current slogan",
      "Current series",
      "Notes"
   ],
   "types":[
      "text",
      "text",
      "text",
      "text",
      "text",
      "text"
   ],
   "rows":[
      [
         "Australian Capital Territory",
         "blue/white",
         "Yaa\u00b7nna",
         "ACT \u00b7 CELEBRATION OF A CENTURY 2013",
         "YIL\u00b700A",
         "Slogan screenprinted on plate"
      ],
      [
         "New South Wales",
         "black/yellow",
         "aa\u00b7nn\u00b7aa",
         "NEW SOUTH WALES",
         "BX\u00b799\u00b7HI",
         "No slogan on current series"
      ],
      [
         "New South Wales",
         "black/white",
         "aaa\u00b7nna",
         "NSW",
         "CPX\u00b712A",
         "Optional white slimline series"
      ],
      [
         "Northern Territory",
         "ochre/white",
         "Ca\u00b7nn\u00b7aa",
         "NT \u00b7 OUTBACK AUSTRALIA",
         "CB\u00b706\u00b7ZZ",
         "New series began in June 2011"
      ],
      [
         "Queensland",
         "maroon/white",
         "nnn\u00b7aaa",
         "QUEENSLAND \u00b7 SUNSHINE STATE",
         "999\u00b7TLG",
         "Slogan embossed on plate"
      ],
      [
         "South Australia",
         "black/white",
         "Snnn\u00b7aaa",
         "SOUTH AUSTRALIA",
         "S000\u00b7AZD",
         "No slogan on current series"
      ],
      [
         "Victoria",
         "blue/white",
         "aaa\u00b7nnn",
         "VICTORIA - THE PLACE TO BE",
         "ZZZ\u00b7562",
         "Current series will be exhausted this year"
      ]
   ]
}
```

The fields represent the following:

- `id`: the table ID.
- `header`: a list of column names in the table.
- `types`: a list of the types of the columns in the table.
- `rows`: a list of rows. Each row is a list of row entries.

Tables are also contained in a corresponding `*.db` file.
This is a SQL database with the same information.
Note that due to the flexible format of HTML tables, the column names of tables in the database have been symbolized.
For example, for a table with the columns `['foo', 'bar']`, the columns in the database are actually `col0` and `col1`.
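For illustration, here is a minimal sketch of querying the symbolized columns directly with Python's built-in `sqlite3` module. The database file name and table ID below are placeholders; the `table_...` naming follows the convention used in `lib/dbengine.py`.

```python
import sqlite3

# Placeholder path and table ID; adjust to your unpacked data.
conn = sqlite3.connect('data/dev.db')

# A table with ID "1-10007452-3" is stored as "table_1_10007452_3"
# (dashes become underscores, see lib/dbengine.py), with columns col0, col1, ...
for row in conn.execute('SELECT col0, col1 FROM table_1_10007452_3 LIMIT 3'):
    print(row)
```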

## Scripts

`evaluate.py` contains the evaluation script, whose options are:

```
usage: evaluate.py [-h] source_file db_file pred_file

positional arguments:
  source_file  source file for the prediction
  db_file      source database for the prediction
  pred_file    predictions by the model

optional arguments:
  -h, --help   show this help message and exit
```

The `pred_file`, which is supplied by the user, should contain lines of serialized JSON objects.
Each JSON object should contain a `query` field that corresponds to the query predicted for the corresponding line in the input `*.jsonl` file and should be similar to the `sql` field of the input.
In particular, it should contain:

- `sel`: the numerical index of the column that is being selected. You can find the actual column from the table.
- `agg`: the numerical index of the aggregation operator that is being used. You can find the actual operator from `Query.agg_ops` in `lib/query.py`.
- `conds`: a list of triplets `(column_index, operator_index, condition)` where:
  - `column_index`: the numerical index of the condition column that is being used. You can find the actual column from the table.
  - `operator_index`: the numerical index of the condition operator that is being used. You can find the actual operator from `Query.cond_ops` in `lib/query.py`.
  - `condition`: the comparison value for the condition, in either `string` or `float` type.

An example predictions file can be found in `test/example.pred.dev.jsonl`.
The `lib` directory contains dependencies of `evaluate.py`.
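As a rough sketch of producing such a file (not the reference implementation; the output file name and the example query are placeholders), note that `evaluate.py` also reads an `error` field from each prediction line, so a falsy value is written for it here:

```python
import json

# Placeholder predictions, one dict per line of the input *.jsonl file.
predictions = [
    {'sel': 1, 'agg': 0, 'conds': [[0, 0, '1998']]},
]

with open('pred.dev.jsonl', 'w') as f:
    for query in predictions:
        # evaluate.py reads the 'query' field and checks 'error' for each line.
        f.write(json.dumps({'query': query, 'error': False}) + '\n')
```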

## Integration Test

We supply a sample predictions file for the dev set in `test/example.pred.dev.jsonl.bz2`.
You can unzip this file using `bunzip2 test/example.pred.dev.jsonl.bz2 -k` to see what a real predictions file looks like.
We distribute a Dockerfile that installs the necessary dependencies of this library and runs the evaluation script on this file.
The Dockerfile also serves as an example of how to use the evaluation script.

To run the test, first build the image from the root directory:

```bash
docker build -t wikisqltest -f test/Dockerfile .
```

Next, run the image:

```bash
docker run --rm --name wikisqltest wikisqltest
```

If everything works correctly, the output should be:

```json
{
  "ex_accuracy": 0.37036632039365774,
  "lf_accuracy": 0.2334609075997813
}
```

data.tar.bz2

Lines changed: 3 additions & 0 deletions
version https://git-lfs.github.com/spec/v1
oid sha256:c2edd896d4da457e1444db2ce7beec41b78ba9d7afd69ea57f7d408c704541d3
size 26363890

evaluate.py

Lines changed: 41 additions & 0 deletions
#!/usr/bin/env python
import json
from argparse import ArgumentParser
from tqdm import tqdm
from lib.dbengine import DBEngine
from lib.query import Query
from lib.common import count_lines


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('source_file', help='source file for the prediction')
    parser.add_argument('db_file', help='source database for the prediction')
    parser.add_argument('pred_file', help='predictions by the model')
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
    exact_match = []
    with open(args.source_file) as fs, open(args.pred_file) as fp:
        grades = []
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)):
            eg = json.loads(ls)  # gold example from the source file
            ep = json.loads(lp)  # predicted example from the predictions file
            qg = Query.from_dict(eg['sql'])
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            pred = ep['error']
            qp = None
            if not ep['error']:
                try:
                    qp = Query.from_dict(ep['query'])
                    pred = engine.execute_query(eg['table_id'], qp, lower=True)
                except Exception as e:
                    pred = repr(e)
            correct = pred == gold  # execution accuracy: same result set
            match = qp == qg        # logical form accuracy: same query
            grades.append(correct)
            exact_match.append(match)
        print(json.dumps({
            'ex_accuracy': sum(grades) / len(grades),
            'lf_accuracy': sum(exact_match) / len(exact_match),
        }, indent=2))

lib/__init__.py

Whitespace-only changes.

lib/common.py

Lines changed: 10 additions & 0 deletions
def count_lines(fname):
    # number of lines in the file, used for tqdm progress totals
    with open(fname) as f:
        return sum(1 for line in f)


def detokenize(tokens):
    # join token surface forms ('gloss') with the whitespace that followed
    # each token ('after') to recover the original string
    ret = ''
    for g, a in zip(tokens['gloss'], tokens['after']):
        ret += g + a
    return ret.strip()
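For illustration, a hypothetical token dict and what `detokenize` returns for it (the field layout follows the function above):

```python
from lib.common import detokenize

# 'gloss' holds the token strings, 'after' the whitespace that followed each token.
tokens = {'gloss': ['who', 'is', 'the', 'manufacturer', '?'],
          'after': [' ', ' ', ' ', '', '']}
print(detokenize(tokens))  # -> 'who is the manufacturer?'
```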

lib/dbengine.py

Lines changed: 49 additions & 0 deletions
import records
import re
from babel.numbers import parse_decimal, NumberFormatError
from lib.query import Query


schema_re = re.compile(r'\((.+)\)')
num_re = re.compile(r'[-+]?\d*\.\d+|\d+')


class DBEngine:

    def __init__(self, fdb):
        self.db = records.Database('sqlite:///{}'.format(fdb))

    def execute_query(self, table_id, query, *args, **kwargs):
        return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs)

    def execute(self, table_id, select_index, aggregation_index, conditions, lower=True):
        if not table_id.startswith('table'):
            table_id = 'table_{}'.format(table_id.replace('-', '_'))
        # recover the column name -> type mapping from the CREATE TABLE statement
        table_info = self.db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql
        schema_str = schema_re.findall(table_info)[0]
        schema = {}
        for tup in schema_str.split(', '):
            c, t = tup.split()
            schema[c] = t
        select = 'col{}'.format(select_index)
        agg = Query.agg_ops[aggregation_index]
        if agg:
            select = '{}({})'.format(agg, select)
        where_clause = []
        where_map = {}
        for col_index, op, val in conditions:
            if lower and isinstance(val, str):
                val = val.lower()
            if schema['col{}'.format(col_index)] == 'real' and not isinstance(val, (int, float)):
                # coerce string conditions on real-typed columns to floats
                try:
                    val = float(parse_decimal(val))
                except NumberFormatError as e:
                    val = float(num_re.findall(val)[0])
            where_clause.append('col{} {} :col{}'.format(col_index, Query.cond_ops[op], col_index))
            where_map['col{}'.format(col_index)] = val
        where_str = ''
        if where_clause:
            where_str = 'WHERE ' + ' AND '.join(where_clause)
        query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str)
        out = self.db.query(query, **where_map)
        return [o.result for o in out]
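As a rough usage sketch (the file names are placeholders; `Query.from_dict` is used the same way in `evaluate.py`), `DBEngine` can execute the gold query of a dataset line against the corresponding database:

```python
import json
from lib.dbengine import DBEngine
from lib.query import Query

# Placeholder paths; use the matching *.jsonl and *.db files from data/.
engine = DBEngine('data/dev.db')
with open('data/dev.jsonl') as f:
    example = json.loads(next(f))

query = Query.from_dict(example['sql'])
print(engine.execute_query(example['table_id'], query, lower=True))
```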
