33import io
44import json
55import logging
6- from builtins import map
76
87from future .utils import iteritems
98from joblib import Parallel , delayed
@@ -25,20 +24,20 @@ def compute_cross_val_metrics(
2524 dataset , engine_class , nb_folds = 5 , train_size_ratio = 1.0 ,
2625 drop_entities = False , include_slot_metrics = True ,
2726 slot_matching_lambda = None , progression_handler = None , num_workers = 1 ,
28- seed = None , out_of_domain_utterances = None ):
27+ seed = None , out_of_domain_utterances = None , intents_filter = None ):
2928 """Compute end-to-end metrics on the dataset using cross validation
3029
3130 Args:
32- dataset (dict or str): Dataset or path to dataset
33- engine_class: Python class to use for training and inference, this
31+ dataset (dict or str): dataset or path to dataset
32+ engine_class: python class to use for training and inference, this
3433 class must inherit from `Engine`
35- nb_folds (int, optional): Number of folds to use for cross validation
34+ nb_folds (int, optional): number of folds to use for cross validation
3635 (default=5)
3736 train_size_ratio (float, optional): ratio of intent utterances to use
3837 for training (default=1.0)
39- drop_entities (bool, optional): Specify whether or not all entity
38+ drop_entities (bool, optional): specify whether or not all entity
4039 values should be removed from training data (default=False)
41- include_slot_metrics (bool, optional): If false, the slots metrics and
40+ include_slot_metrics (bool, optional): if false, the slots metrics and
4241 the slots parsing errors will not be reported (default=True)
4342 slot_matching_lambda (lambda, optional):
4443 lambda expected_slot, actual_slot -> bool,
@@ -52,9 +51,13 @@ class must inherit from `Engine`
5251 num_workers (int, optional): number of workers to use. Each worker
5352 is assigned a certain number of splits (default=1)
5453 seed (int, optional): seed for the split creation
55- out_of_domain_utterances (list, optional): If defined, list of
56- out-of-domain utterances to be added to the pool of test utterances
54+ out_of_domain_utterances (list, optional): if defined, list of
55+ out-of-domain utterances to be added to the pool of test utterances
5756 in each split
57+ intents_filter (list of str, optional): if defined, at inference times
58+ test utterances will be restricted to the ones belonging to this
59+ filter. Moreover, if the parsing API allows it, the inference will
60+ be made using this intents filter.
5861
5962 Returns:
6063 dict: Metrics results containing the following data
@@ -72,7 +75,7 @@ class must inherit from `Engine`
7275 try :
7376 splits = create_shuffle_stratified_splits (
7477 dataset , nb_folds , train_size_ratio , drop_entities ,
75- seed , out_of_domain_utterances )
78+ seed , out_of_domain_utterances , intents_filter )
7679 except NotEnoughDataError as e :
7780 logger .warning ("Skipping metrics computation because of: %s"
7881 % e .message )
@@ -93,14 +96,14 @@ def compute_metrics(split_):
9396 logger .info ("Computing metrics for dataset split ..." )
9497 return compute_split_metrics (
9598 engine_class , split_ , intent_list , include_slot_metrics ,
96- slot_matching_lambda )
99+ slot_matching_lambda , intents_filter )
97100
98101 effective_num_workers = min (num_workers , len (splits ))
99102 if effective_num_workers > 1 :
100103 parallel = Parallel (n_jobs = effective_num_workers )
101104 results = parallel (delayed (compute_metrics )(split ) for split in splits )
102105 else :
103- results = map ( compute_metrics , splits )
106+ results = [ compute_metrics ( s ) for s in splits ]
104107
105108 for result in enumerate (results ):
106109 split_index , (split_metrics , errors , confusion_matrix ) = result
@@ -137,7 +140,7 @@ def compute_metrics(split_):
137140
138141def compute_train_test_metrics (
139142 train_dataset , test_dataset , engine_class , include_slot_metrics = True ,
140- slot_matching_lambda = None ):
143+ slot_matching_lambda = None , intents_filter = None ):
141144 """Compute end-to-end metrics on `test_dataset` after having trained on
142145 `train_dataset`
143146
@@ -155,6 +158,10 @@ class must inherit from `Engine`
155158 metrics, otherwise exact match will be used.
156159 `expected_slot` corresponds to the slot as defined in the dataset,
157160 and `actual_slot` corresponds to the slot as returned by the NLU
161+ intents_filter (list of str, optional): if defined, at inference times
162+ test utterances will be restricted to the ones belonging to this
163+ filter. Moreover, if the parsing API allows it, the inference will
164+ be made using this intents filter.
158165
159166 Returns
160167 dict: Metrics results containing the following data
@@ -184,12 +191,13 @@ class must inherit from `Engine`
184191 (intent_name , utterance )
185192 for intent_name , intent_data in iteritems (test_dataset [INTENTS ])
186193 for utterance in intent_data [UTTERANCES ]
194+ if intents_filter is None or intent_name in intents_filter
187195 ]
188196
189197 logger .info ("Computing metrics..." )
190198 metrics , errors , confusion_matrix = compute_engine_metrics (
191199 engine , test_utterances , intent_list , include_slot_metrics ,
192- slot_matching_lambda )
200+ slot_matching_lambda , intents_filter )
193201 metrics = compute_precision_recall_f1 (metrics )
194202 average_metrics = compute_average_metrics (metrics )
195203 nb_utterances = {intent : len (data [UTTERANCES ])
0 commit comments