1- from __future__ import division , print_function , unicode_literals
2-
3- import io
41import json
52import logging
3+ from pathlib import Path
64
7- from future .utils import iteritems
85from joblib import Parallel , delayed
9- from past .builtins import basestring
106
117from snips_nlu_metrics .utils .constants import (
12- AVERAGE_METRICS , CONFUSION_MATRIX , INTENTS , INTENT_UTTERANCES , METRICS ,
13- PARSING_ERRORS , UTTERANCES )
8+ AVERAGE_METRICS ,
9+ CONFUSION_MATRIX ,
10+ INTENTS ,
11+ INTENT_UTTERANCES ,
12+ METRICS ,
13+ PARSING_ERRORS ,
14+ UTTERANCES ,
15+ )
1416from snips_nlu_metrics .utils .exception import NotEnoughDataError
1517from snips_nlu_metrics .utils .metrics_utils import (
16- aggregate_matrices , aggregate_metrics , compute_average_metrics ,
17- compute_engine_metrics , compute_precision_recall_f1 , compute_split_metrics ,
18- create_shuffle_stratified_splits )
18+ aggregate_matrices ,
19+ aggregate_metrics ,
20+ compute_average_metrics ,
21+ compute_engine_metrics ,
22+ compute_precision_recall_f1 ,
23+ compute_split_metrics ,
24+ create_shuffle_stratified_splits ,
25+ )
1926
# Module-level logger, named after this module so log records can be filtered per-module.
2027logger = logging .getLogger (__name__ )
2128
2229
2330def compute_cross_val_metrics (
24- dataset , engine_class , nb_folds = 5 , train_size_ratio = 1.0 ,
25- drop_entities = False , include_slot_metrics = True ,
26- slot_matching_lambda = None , progression_handler = None , num_workers = 1 ,
27- seed = None , out_of_domain_utterances = None , intents_filter = None ):
31+ dataset ,
32+ engine_class ,
33+ nb_folds = 5 ,
34+ train_size_ratio = 1.0 ,
35+ drop_entities = False ,
36+ include_slot_metrics = True ,
37+ slot_matching_lambda = None ,
38+ progression_handler = None ,
39+ num_workers = 1 ,
40+ seed = None ,
41+ out_of_domain_utterances = None ,
42+ intents_filter = None ,
43+ ):
2844 """Compute end-to-end metrics on the dataset using cross validation
2945
3046 Args:
@@ -68,14 +84,20 @@ class must inherit from `Engine`
6884 - "average_metrics": the metrics averaged over all intents
6985 """
7086
71- if isinstance (dataset , basestring ):
72- with io .open (dataset , encoding = "utf8" ) as f :
87+ if isinstance (dataset , ( str , Path ) ):
88+ with Path ( dataset ) .open (encoding = "utf8" ) as f :
7389 dataset = json .load (f )
7490
7591 try :
7692 splits = create_shuffle_stratified_splits (
77- dataset , nb_folds , train_size_ratio , drop_entities ,
78- seed , out_of_domain_utterances , intents_filter )
93+ dataset ,
94+ nb_folds ,
95+ train_size_ratio ,
96+ drop_entities ,
97+ seed ,
98+ out_of_domain_utterances ,
99+ intents_filter ,
100+ )
79101 except NotEnoughDataError as e :
80102 logger .warning ("Not enough data, skipping metrics computation: %r" , e )
81103 return {
@@ -94,8 +116,13 @@ class must inherit from `Engine`
94116 def compute_metrics (split_ ):
95117 logger .info ("Computing metrics for dataset split ..." )
96118 return compute_split_metrics (
97- engine_class , split_ , intent_list , include_slot_metrics ,
98- slot_matching_lambda , intents_filter )
119+ engine_class ,
120+ split_ ,
121+ intent_list ,
122+ include_slot_metrics ,
123+ slot_matching_lambda ,
124+ intents_filter ,
125+ )
99126
100127 effective_num_workers = min (num_workers , len (splits ))
101128 if effective_num_workers > 1 :
@@ -107,26 +134,28 @@ def compute_metrics(split_):
107134 for result in enumerate (results ):
108135 split_index , (split_metrics , errors , confusion_matrix ) = result
109136 global_metrics = aggregate_metrics (
110- global_metrics , split_metrics , include_slot_metrics )
137+ global_metrics , split_metrics , include_slot_metrics
138+ )
111139 global_confusion_matrix = aggregate_matrices (
112- global_confusion_matrix , confusion_matrix )
140+ global_confusion_matrix , confusion_matrix
141+ )
113142 global_errors += errors
114- logger .info ("Done computing %d/%d splits"
115- % (split_index + 1 , total_splits ))
143+ logger .info ("Done computing %d/%d splits" % (split_index + 1 , total_splits ))
116144
117145 if progression_handler is not None :
118- progression_handler (
119- float (split_index + 1 ) / float (total_splits ))
146+ progression_handler (float (split_index + 1 ) / float (total_splits ))
120147
121148 global_metrics = compute_precision_recall_f1 (global_metrics )
122149
123150 average_metrics = compute_average_metrics (
124151 global_metrics ,
125- ignore_none_intent = True if out_of_domain_utterances is None else False )
152+ ignore_none_intent = True if out_of_domain_utterances is None else False ,
153+ )
126154
127- nb_utterances = {intent : len (data [UTTERANCES ])
128- for intent , data in iteritems (dataset [INTENTS ])}
129- for intent , metrics in iteritems (global_metrics ):
155+ nb_utterances = {
156+ intent : len (data [UTTERANCES ]) for intent , data in dataset [INTENTS ].items ()
157+ }
158+ for intent , metrics in global_metrics .items ():
130159 metrics [INTENT_UTTERANCES ] = nb_utterances .get (intent , 0 )
131160
132161 return {
@@ -138,8 +167,13 @@ def compute_metrics(split_):
138167
139168
140169def compute_train_test_metrics (
141- train_dataset , test_dataset , engine_class , include_slot_metrics = True ,
142- slot_matching_lambda = None , intents_filter = None ):
170+ train_dataset ,
171+ test_dataset ,
172+ engine_class ,
173+ include_slot_metrics = True ,
174+ slot_matching_lambda = None ,
175+ intents_filter = None ,
176+ ):
143177 """Compute end-to-end metrics on `test_dataset` after having trained on
144178 `train_dataset`
145179
@@ -171,12 +205,12 @@ class must inherit from `Engine`
171205 - "average_metrics": the metrics averaged over all intents
172206 """
173207
174- if isinstance (train_dataset , basestring ):
175- with io .open (train_dataset , encoding = "utf8" ) as f :
208+ if isinstance (train_dataset , ( str , Path ) ):
209+ with Path ( train_dataset ) .open (encoding = "utf8" ) as f :
176210 train_dataset = json .load (f )
177211
178- if isinstance (test_dataset , basestring ):
179- with io .open (test_dataset , encoding = "utf8" ) as f :
212+ if isinstance (test_dataset , ( str , Path ) ):
213+ with Path ( test_dataset ) .open (encoding = "utf8" ) as f :
180214 test_dataset = json .load (f )
181215
182216 intent_list = set (train_dataset ["intents" ])
@@ -188,20 +222,26 @@ class must inherit from `Engine`
188222 engine .fit (train_dataset )
189223 test_utterances = [
190224 (intent_name , utterance )
191- for intent_name , intent_data in iteritems ( test_dataset [INTENTS ])
225+ for intent_name , intent_data in test_dataset [INTENTS ]. items ( )
192226 for utterance in intent_data [UTTERANCES ]
193227 if intents_filter is None or intent_name in intents_filter
194228 ]
195229
196230 logger .info ("Computing metrics..." )
197231 metrics , errors , confusion_matrix = compute_engine_metrics (
198- engine , test_utterances , intent_list , include_slot_metrics ,
199- slot_matching_lambda , intents_filter )
232+ engine ,
233+ test_utterances ,
234+ intent_list ,
235+ include_slot_metrics ,
236+ slot_matching_lambda ,
237+ intents_filter ,
238+ )
200239 metrics = compute_precision_recall_f1 (metrics )
201240 average_metrics = compute_average_metrics (metrics )
202- nb_utterances = {intent : len (data [UTTERANCES ])
203- for intent , data in iteritems (train_dataset [INTENTS ])}
204- for intent , intent_metrics in iteritems (metrics ):
241+ nb_utterances = {
242+ intent : len (data [UTTERANCES ]) for intent , data in train_dataset [INTENTS ].items ()
243+ }
244+ for intent , intent_metrics in metrics .items ():
205245 intent_metrics [INTENT_UTTERANCES ] = nb_utterances .get (intent , 0 )
206246 return {
207247 CONFUSION_MATRIX : confusion_matrix ,
0 commit comments