Skip to content

Commit feb11b5

Browse files
committed
add separate package for cli and update setup entry points
1 parent e80c91e commit feb11b5

File tree

6 files changed

+46
-66
lines changed

6 files changed

+46
-66
lines changed

README.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,17 @@ These example commands are executed from within the repository folder.
2222

2323
### Training
2424
```shell script
25-
python train.py --parser lin --dir models/lin --conll ../conll2016
25+
python cli/train.py lin path/to/model path/to/conll
2626
```
27+
The training data is in JSON format; the data folder contains the subfolders `en.{train,dev,test}`
28+
with files `relations.json` and `parses.json`.
2729

28-
### Prediction
30+
### Evaluation
2931
```shell script
30-
python parse.py -i path/to/some/textfile -m models/lin
32+
python cli/test.py lin path/to/model path/to/conll
3133
```
3234

33-
### Evaluation
35+
### Prediction
3436
```shell script
35-
python test.py --parser lin --dir models/lin --conll ../conll2016
37+
python cli/parse.py -i path/to/some/textfile -m models/lin
3638
```

cli/__init__.py

Whitespace-only changes.
File renamed without changes.
Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
import os
22

3+
import click
34
from tqdm import tqdm
45

56
from discopy.parsers import get_parser
6-
from discopy.semi_utils import get_arguments
77

8-
args = get_arguments()
9-
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
8+
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('CUDA_VISIBLE_DEVICES', '')
109

1110
from discopy.data.conll16 import get_conll_dataset
1211
from discopy.utils import init_logger
1312
import discopy.evaluate.exact
1413

15-
os.makedirs(args.dir, exist_ok=True)
16-
1714
logger = init_logger()
1815

1916

@@ -41,33 +38,27 @@ def evaluate_parser(pdtb_gold, pdtb_pred, threshold=0.7):
4138
return discopy.evaluate.exact.evaluate_all(gold_relations, pred_relations, threshold=threshold)
4239

4340

44-
if __name__ == '__main__':
45-
parses_train, pdtb_train = get_conll_dataset(args.conll, 'en.train', load_trees=True, connective_mapping=True)
46-
parses_val, pdtb_val = get_conll_dataset(args.conll, 'en.dev', load_trees=True, connective_mapping=True)
47-
parses_test, pdtb_test = get_conll_dataset(args.conll, 'en.test', load_trees=True, connective_mapping=True)
48-
parses_blind, pdtb_blind = get_conll_dataset(args.conll, 'en.blind-test', load_trees=True, connective_mapping=True)
49-
41+
@click.command()
42+
@click.argument('parser', type=str)
43+
@click.argument('model-path', type=str)
44+
@click.argument('conll-path', type=str)
45+
@click.option('-t', '--threshold', default=0.9, type=str)
46+
def main(parser, model_path, conll_path, threshold):
47+
parses_test, pdtb_test = get_conll_dataset(conll_path, 'en.test', load_trees=True, connective_mapping=True)
48+
parses_blind, pdtb_blind = get_conll_dataset(conll_path, 'en.blind-test', load_trees=True, connective_mapping=True)
5049
logger.info('Init Parser...')
51-
parser = get_parser(args.parser)
52-
parser_path = args.dir
53-
54-
if args.train:
55-
logger.info('Train end-to-end Parser...')
56-
parser.fit(pdtb_train, parses_train, pdtb_val, parses_val)
57-
parser.save(os.path.join(args.dir))
58-
elif os.path.exists(args.dir):
59-
logger.info('Load pre-trained Parser...')
60-
parser.load(args.dir)
61-
else:
62-
raise ValueError('Training and Loading not clear')
63-
50+
parser = get_parser(parser)
51+
logger.info('Load pre-trained Parser...')
52+
parser.load(model_path)
6453
logger.info('component evaluation (test)')
6554
parser.score(pdtb_test, parses_test)
66-
6755
logger.info('extract discourse relations from test data')
6856
pdtb_pred = extract_discourse_relations(parser, parses_test)
69-
evaluate_parser(pdtb_test, pdtb_pred, threshold=args.threshold)
70-
57+
evaluate_parser(pdtb_test, pdtb_pred, threshold=threshold)
7158
logger.info('extract discourse relations from BLIND data')
7259
pdtb_pred = extract_discourse_relations(parser, parses_blind)
73-
evaluate_parser(pdtb_blind, pdtb_pred, threshold=args.threshold)
60+
evaluate_parser(pdtb_blind, pdtb_pred, threshold=threshold)
61+
62+
63+
if __name__ == '__main__':
64+
main()
Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,15 @@
1-
import argparse
21
import os
32

4-
from discopy.data.conll16 import get_conll_dataset
5-
from discopy.parsers import get_parser
6-
7-
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
8-
3+
import click
94
from tqdm import tqdm
105

6+
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('CUDA_VISIBLE_DEVICES', '')
7+
8+
from discopy.data.conll16 import get_conll_dataset
9+
from discopy.parsers import get_parser
1110
import discopy.evaluate.exact
1211
from discopy.utils import init_logger
1312

14-
argument_parser = argparse.ArgumentParser()
15-
argument_parser.add_argument("--dir", help="",
16-
default='tmp')
17-
argument_parser.add_argument("--conll", help="",
18-
default='')
19-
argument_parser.add_argument("--parser", help="",
20-
default='lin')
21-
argument_parser.add_argument("--threshold", help="",
22-
default=0.9, type=float)
23-
args = argument_parser.parse_args()
24-
25-
os.makedirs(args.dir, exist_ok=True)
26-
2713
logger = init_logger()
2814

2915

@@ -51,19 +37,19 @@ def evaluate_parser(pdtb_gold, pdtb_pred, threshold=0.7):
5137
return discopy.evaluate.exact.evaluate_all(gold_relations, pred_relations, threshold=threshold)
5238

5339

54-
if __name__ == '__main__':
55-
parses_train, pdtb_train = get_conll_dataset(args.conll, 'en.train', load_trees=True, connective_mapping=True)
56-
parses_val, pdtb_val = get_conll_dataset(args.conll, 'en.dev', load_trees=True, connective_mapping=True)
57-
parses_test, pdtb_test = get_conll_dataset(args.conll, 'en.test', load_trees=True, connective_mapping=True)
58-
parses_blind, pdtb_blind = get_conll_dataset(args.conll, 'en.blind-test', load_trees=True, connective_mapping=True)
59-
40+
@click.command()
41+
@click.argument('parser', type=str)
42+
@click.argument('model-path', type=str)
43+
@click.argument('conll-path', type=str)
44+
def main(parser, model_path, conll_path):
45+
parses_train, pdtb_train = get_conll_dataset(conll_path, 'en.train', load_trees=True, connective_mapping=True)
46+
parses_val, pdtb_val = get_conll_dataset(conll_path, 'en.dev', load_trees=True, connective_mapping=True)
6047
logger.info('Init Parser...')
61-
parser = get_parser(args.parser)
62-
48+
parser = get_parser(parser)
6349
logger.info('Train end-to-end Parser...')
6450
parser.fit(pdtb_train, parses_train, pdtb_val, parses_val)
65-
parser.save(os.path.join(args.dir))
51+
parser.save(os.path.join(model_path))
52+
6653

67-
logger.info('extract discourse relations from test data')
68-
pdtb_pred = extract_discourse_relations(parser, parses_test)
69-
all_results = evaluate_parser(pdtb_test, pdtb_pred, threshold=args.threshold)
54+
if __name__ == '__main__':
55+
main()

setup.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@
2828
zip_safe=False,
2929
entry_points={
3030
'console_scripts': [
31-
'discopy=main:main',
32-
'discopy-parse=parse'
31+
'discopy-train=cli.train:main',
32+
'discopy-test=cli.test:main',
33+
'discopy-parse=cli.parse:main',
3334
],
3435
}
3536
)

0 commit comments

Comments
 (0)