22
33import io
44import json
5+ import logging
56from builtins import map
67
78from future .utils import iteritems
9+ from joblib import Parallel , delayed
810from past .builtins import basestring
9- from pathos .multiprocessing import Pool
1011
1112from snips_nlu_metrics .utils .constants import (
1213 AVERAGE_METRICS , CONFUSION_MATRIX , INTENTS , INTENT_UTTERANCES , METRICS ,
1718 compute_engine_metrics , compute_precision_recall_f1 , compute_split_metrics ,
1819 create_shuffle_stratified_splits )
1920
21+ logger = logging .getLogger (__name__ )
22+
2023
2124def compute_cross_val_metrics (
2225 dataset , engine_class , nb_folds = 5 , train_size_ratio = 1.0 ,
2326 drop_entities = False , include_slot_metrics = True ,
2427 slot_matching_lambda = None , progression_handler = None , num_workers = 1 ,
25- seed = None ):
28+ seed = None , out_of_domain_utterances = None ):
2629 """Compute end-to-end metrics on the dataset using cross validation
2730
2831 Args:
@@ -49,13 +52,17 @@ class must inherit from `Engine`
4952 num_workers (int, optional): number of workers to use. Each worker
5053 is assigned a certain number of splits (default=1)
5154 seed (int, optional): seed for the split creation
55+ out_of_domain_utterances (list, optional): If defined, list of
56+ out-of-domain utterances to be added to the pool of test utterances
57+ in each split
5258
5359 Returns:
5460 dict: Metrics results containing the following data
55-
61+
5662 - "metrics": the computed metrics
5763 - "parsing_errors": the list of parsing errors
58-
64+ - "confusion_matrix": the computed confusion matrix
65+ - "average_metrics": the metrics averaged over all intents
5966 """
6067
6168 if isinstance (dataset , basestring ):
@@ -64,9 +71,11 @@ class must inherit from `Engine`
6471
6572 try :
6673 splits = create_shuffle_stratified_splits (
67- dataset , nb_folds , train_size_ratio , drop_entities , seed )
74+ dataset , nb_folds , train_size_ratio , drop_entities ,
75+ seed , out_of_domain_utterances )
6876 except NotEnoughDataError as e :
69- print ("Skipping metrics computation because of: %s" % e .message )
77+ logger .warning ("Skipping metrics computation because of: %s"
78+ % e .message )
7079 return {
7180 AVERAGE_METRICS : None ,
7281 CONFUSION_MATRIX : None ,
@@ -80,33 +89,38 @@ class must inherit from `Engine`
8089 global_errors = []
8190 total_splits = len (splits )
8291
83- if num_workers > 1 :
84- effective_num_workers = min (num_workers , len (splits ))
85- pool = Pool (effective_num_workers )
86- runner = pool .imap_unordered
87- else :
88- runner = map
def compute_metrics(split_):
    """Compute the metrics of a single dataset split.

    Thin wrapper around `compute_split_metrics` that binds the engine
    class, the intent list and the slot-metrics options taken from the
    enclosing scope, so that only the split has to be passed by the
    sequential `map` runner or by the parallel runner.

    Args:
        split_: one of the splits produced for the cross validation

    Returns:
        The value returned by `compute_split_metrics` for this split,
        which the caller unpacks as
        `(split_metrics, errors, confusion_matrix)`.
    """
    logger.info("Computing metrics for dataset split ...")
    split_result = compute_split_metrics(
        engine_class,
        split_,
        intent_list,
        include_slot_metrics,
        slot_matching_lambda)
    return split_result
8997
90- results = runner (
91- lambda split :
92- compute_split_metrics (engine_class , split , intent_list ,
93- include_slot_metrics , slot_matching_lambda ),
94- splits )
98+ effective_num_workers = min (num_workers , len (splits ))
99+ if effective_num_workers > 1 :
100+ parallel = Parallel (n_jobs = effective_num_workers )
101+ results = parallel (delayed (compute_metrics )(split ) for split in splits )
102+ else :
103+ results = map (compute_metrics , splits )
95104
96- for split_index , ( split_metrics , errors , confusion_matrix ) in \
97- enumerate ( results ):
105+ for result in enumerate ( results ):
106+ split_index , ( split_metrics , errors , confusion_matrix ) = result
98107 global_metrics = aggregate_metrics (
99108 global_metrics , split_metrics , include_slot_metrics )
100109 global_confusion_matrix = aggregate_matrices (
101110 global_confusion_matrix , confusion_matrix )
102111 global_errors += errors
112+ logger .info ("Done computing %d/%d splits"
113+ % (split_index + 1 , total_splits ))
103114
104115 if progression_handler is not None :
105116 progression_handler (
106117 float (split_index + 1 ) / float (total_splits ))
107118
108119 global_metrics = compute_precision_recall_f1 (global_metrics )
109- average_metrics = compute_average_metrics (global_metrics )
120+
121+ average_metrics = compute_average_metrics (
122+ global_metrics ,
123+ ignore_none_intent = True if out_of_domain_utterances is None else False )
110124
111125 nb_utterances = {intent : len (data [UTTERANCES ])
112126 for intent , data in iteritems (dataset [INTENTS ])}
@@ -147,6 +161,8 @@ class must inherit from `Engine`
147161
148162 - "metrics": the computed metrics
149163 - "parsing_errors": the list of parsing errors
164+ - "confusion_matrix": the computed confusion matrix
165+ - "average_metrics": the metrics averaged over all intents
150166 """
151167
152168 if isinstance (train_dataset , basestring ):
@@ -161,13 +177,16 @@ class must inherit from `Engine`
161177 intent_list .update (test_dataset ["intents" ])
162178 intent_list = sorted (intent_list )
163179
180+ logger .info ("Training engine..." )
164181 engine = engine_class ()
165182 engine .fit (train_dataset )
166183 test_utterances = [
167184 (intent_name , utterance )
168185 for intent_name , intent_data in iteritems (test_dataset [INTENTS ])
169186 for utterance in intent_data [UTTERANCES ]
170187 ]
188+
189+ logger .info ("Computing metrics..." )
171190 metrics , errors , confusion_matrix = compute_engine_metrics (
172191 engine , test_utterances , intent_list , include_slot_metrics ,
173192 slot_matching_lambda )
0 commit comments