@@ -199,12 +199,10 @@ def grid_search(self, params, n_jobs=2, verbose=1):
199199 grid_result = grid .fit (self .X_train , self .y_train )
200200 return grid_result .best_estimator_ , grid_result .best_params_ , grid_result .best_score_
201201
202- def determine_best_model (self , train = True ):
202+ def determine_best_model (self ):
203203 """
204204 Loads best estimators and determine which is best for test data,
205205 and then set it to `self.model`.
206- if `train` is True, then train that model on train data, so the model
207- will be ready for inference.
208206 In case of regression, the metric used is MSE and accuracy for classification.
209207 Note that the execution of this method may take several minutes due
210208 to training all estimators (stored in `grid` folder) for determining the best possible one.
@@ -240,11 +238,9 @@ def determine_best_model(self, train=True):
240238 result .append ((detector .model , accuracy ))
241239
242240 # sort the result
243- if self .classification :
244- result = sorted (result , key = lambda item : item [1 ], reverse = True )
245- else :
246- # regression, best is the lower, not the higher
247- result = sorted (result , key = lambda item : item [1 ], reverse = False )
241+ # regression: best is the lower, not the higher
242+ # classification: best is higher, not the lower
243+ result = sorted (result , key = lambda item : item [1 ], reverse = self .classification )
248244 best_estimator = result [0 ][0 ]
249245 accuracy = result [0 ][1 ]
250246 self .model = best_estimator
0 commit comments