|
1 | 1 | import os |
2 | 2 |
|
| 3 | +import click |
3 | 4 | from tqdm import tqdm |
4 | 5 |
|
5 | 6 | from discopy.parsers import get_parser |
6 | | -from discopy.semi_utils import get_arguments |
7 | 7 |
|
8 | | -args = get_arguments() |
9 | | -os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu |
| 8 | +os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('CUDA_VISIBLE_DEVICES', '') |
10 | 9 |
|
11 | 10 | from discopy.data.conll16 import get_conll_dataset |
12 | 11 | from discopy.utils import init_logger |
13 | 12 | import discopy.evaluate.exact |
14 | 13 |
|
15 | | -os.makedirs(args.dir, exist_ok=True) |
16 | | - |
17 | 14 | logger = init_logger() |
18 | 15 |
|
19 | 16 |
|
@@ -41,33 +38,27 @@ def evaluate_parser(pdtb_gold, pdtb_pred, threshold=0.7): |
41 | 38 | return discopy.evaluate.exact.evaluate_all(gold_relations, pred_relations, threshold=threshold) |
42 | 39 |
|
43 | 40 |
|
44 | | -if __name__ == '__main__': |
45 | | - parses_train, pdtb_train = get_conll_dataset(args.conll, 'en.train', load_trees=True, connective_mapping=True) |
46 | | - parses_val, pdtb_val = get_conll_dataset(args.conll, 'en.dev', load_trees=True, connective_mapping=True) |
47 | | - parses_test, pdtb_test = get_conll_dataset(args.conll, 'en.test', load_trees=True, connective_mapping=True) |
48 | | - parses_blind, pdtb_blind = get_conll_dataset(args.conll, 'en.blind-test', load_trees=True, connective_mapping=True) |
49 | | - |
| 41 | +@click.command() |
| 42 | +@click.argument('parser', type=str) |
| 43 | +@click.argument('model-path', type=str) |
| 44 | +@click.argument('conll-path', type=str) |
| 45 | +@click.option('-t', '--threshold', default=0.9, type=str) |
| 46 | +def main(parser, model_path, conll_path, threshold): |
| 47 | + parses_test, pdtb_test = get_conll_dataset(conll_path, 'en.test', load_trees=True, connective_mapping=True) |
| 48 | + parses_blind, pdtb_blind = get_conll_dataset(conll_path, 'en.blind-test', load_trees=True, connective_mapping=True) |
50 | 49 | logger.info('Init Parser...') |
51 | | - parser = get_parser(args.parser) |
52 | | - parser_path = args.dir |
53 | | - |
54 | | - if args.train: |
55 | | - logger.info('Train end-to-end Parser...') |
56 | | - parser.fit(pdtb_train, parses_train, pdtb_val, parses_val) |
57 | | - parser.save(os.path.join(args.dir)) |
58 | | - elif os.path.exists(args.dir): |
59 | | - logger.info('Load pre-trained Parser...') |
60 | | - parser.load(args.dir) |
61 | | - else: |
62 | | - raise ValueError('Training and Loading not clear') |
63 | | - |
| 50 | + parser = get_parser(parser) |
| 51 | + logger.info('Load pre-trained Parser...') |
| 52 | + parser.load(model_path) |
64 | 53 | logger.info('component evaluation (test)') |
65 | 54 | parser.score(pdtb_test, parses_test) |
66 | | - |
67 | 55 | logger.info('extract discourse relations from test data') |
68 | 56 | pdtb_pred = extract_discourse_relations(parser, parses_test) |
69 | | - evaluate_parser(pdtb_test, pdtb_pred, threshold=args.threshold) |
70 | | - |
| 57 | + evaluate_parser(pdtb_test, pdtb_pred, threshold=threshold) |
71 | 58 | logger.info('extract discourse relations from BLIND data') |
72 | 59 | pdtb_pred = extract_discourse_relations(parser, parses_blind) |
73 | | - evaluate_parser(pdtb_blind, pdtb_pred, threshold=args.threshold) |
| 60 | + evaluate_parser(pdtb_blind, pdtb_pred, threshold=threshold) |
| 61 | + |
| 62 | + |
| 63 | +if __name__ == '__main__': |
| 64 | + main() |
0 commit comments