Skip to content

Commit 2a1bd63

Browse files
committed
now balancing the dataset when one of classes has 0 samples is ignored
1 parent a5594b7 commit 2a1bd63

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

data_extractor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,11 @@ def _balance_data(self, partition):
169169
count.append(len([ e for e in emotions if e == emotion]))
170170
# get the minimum data samples to balance to
171171
minimum = min(count)
172+
if minimum == 0:
173+
# won't balance, otherwise 0 samples will be loaded
174+
print("[!] One class has 0 samples, setting balance to False")
175+
self.balance = False
176+
return
172177
if self.verbose:
173178
print("[*] Balancing the dataset to the minimum value:", minimum)
174179
d = defaultdict(list)
@@ -239,5 +244,6 @@ def load_data(train_desc_files, test_desc_files, audio_config=None, classificati
239244
"y_train": np.array(audiogen.train_emotions),
240245
"y_test": np.array(audiogen.test_emotions),
241246
"train_audio_paths": audiogen.train_audio_paths,
242-
"test_audio_paths": audiogen.test_audio_paths
247+
"test_audio_paths": audiogen.test_audio_paths,
248+
"balance": audiogen.balance,
243249
}

emotion_recognition.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def load_data(self):
148148
self.y_test = result['y_test']
149149
self.train_audio_paths = result['train_audio_paths']
150150
self.test_audio_paths = result['test_audio_paths']
151+
self.balance = result["balance"]
151152
if self.verbose:
152153
print("[+] Data loaded")
153154
self.data_loaded = True
@@ -187,14 +188,14 @@ def predict_proba(self, audio_path):
187188
else:
188189
raise NotImplementedError("Probability prediction doesn't make sense for regression")
189190

190-
def grid_search(self, params, n_jobs=2):
191+
def grid_search(self, params, n_jobs=2, verbose=1):
191192
"""
192193
Performs GridSearchCV on `params` passed on the `self.model`
193194
And returns the tuple: (best_estimator, best_params, best_score).
194195
"""
195196
score = accuracy_score if self.classification else mean_absolute_error
196197
grid = GridSearchCV(estimator=self.model, param_grid=params, scoring=make_scorer(score),
197-
n_jobs=n_jobs, verbose=1, cv=3)
198+
n_jobs=n_jobs, verbose=verbose, cv=3)
198199
grid_result = grid.fit(self.X_train, self.y_train)
199200
return grid_result.best_estimator_, grid_result.best_params_, grid_result.best_score_
200201

0 commit comments

Comments
 (0)