Skip to content

Commit f462ede

Browse files
committed
remove 'train' param on determine_best_model() function
1 parent 94d8c5a commit f462ede

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

emotion_recognition.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,10 @@ def grid_search(self, params, n_jobs=2, verbose=1):
199199
grid_result = grid.fit(self.X_train, self.y_train)
200200
return grid_result.best_estimator_, grid_result.best_params_, grid_result.best_score_
201201

202-
def determine_best_model(self, train=True):
202+
def determine_best_model(self):
203203
"""
204204
Loads best estimators and determine which is best for test data,
205205
and then set it to `self.model`.
206-
if `train` is True, then train that model on train data, so the model
207-
will be ready for inference.
208206
In case of regression, the metric used is MSE and accuracy for classification.
209207
Note that the execution of this method may take several minutes due
210208
to training all estimators (stored in `grid` folder) for determining the best possible one.
@@ -240,11 +238,9 @@ def determine_best_model(self, train=True):
240238
result.append((detector.model, accuracy))
241239

242240
# sort the result
243-
if self.classification:
244-
result = sorted(result, key=lambda item: item[1], reverse=True)
245-
else:
246-
# regression, best is the lower, not the higher
247-
result = sorted(result, key=lambda item: item[1], reverse=False)
241+
# regression: best is the lower, not the higher
242+
# classification: best is higher, not the lower
243+
result = sorted(result, key=lambda item: item[1], reverse=self.classification)
248244
best_estimator = result[0][0]
249245
accuracy = result[0][1]
250246
self.model = best_estimator

0 commit comments

Comments
 (0)