@@ -40,44 +40,46 @@ class DeepEmotionRecognizer(EmotionRecognizer):
4040 """
4141 The Deep Learning version of the Emotion Recognizer.
4242 This class uses RNN (LSTM, GRU, etc.) and Dense layers.
43+ #TODO add CNNs
4344 """
4445 def __init__ (self , ** kwargs ):
4546 """
4647 params:
4748 emotions (list): list of emotions to be used. Note that these emotions must be available in
4849 RAVDESS_TESS & EMODB Datasets, available nine emotions are the following:
4950 'neutral', 'calm', 'happy', 'sad', 'angry', 'fear', 'disgust', 'ps' ( pleasant surprised ), 'boredom'.
50- tess_ravdess (bool): whether to use TESS & RAVDESS Speech datasets, default is True
51- emodb (bool): whether to use EMO-DB Speech dataset, default is True,
51+ Default is ["sad", "neutral", "happy"].
52+ tess_ravdess (bool): whether to use TESS & RAVDESS Speech datasets, default is True.
53+ emodb (bool): whether to use EMO-DB Speech dataset, default is True.
5254 custom_db (bool): whether to use custom Speech dataset that is located in `data/train-custom`
53- and `data/test-custom`, default is True
54- tess_ravdess_name (str): the name of the output CSV file for TESS&RAVDESS dataset, default is "tess_ravdess.csv"
55- emodb_name (str): the name of the output CSV file for EMO-DB dataset, default is "emodb.csv"
56- custom_db_name (str): the name of the output CSV file for the custom dataset, default is "custom.csv"
55+ and `data/test-custom`, default is True.
56+ tess_ravdess_name (str): the name of the output CSV file for TESS&RAVDESS dataset, default is "tess_ravdess.csv".
57+ emodb_name (str): the name of the output CSV file for EMO-DB dataset, default is "emodb.csv".
58+ custom_db_name (str): the name of the output CSV file for the custom dataset, default is "custom.csv".
5759 features (list): list of speech features to use, default is ["mfcc", "chroma", "mel"]
58- (i.e MFCC, Chroma and MEL spectrogram )
59- classification (bool): whether to use classification or regression, default is True
60- balance (bool): whether to balance the dataset ( both training and testing ), default is True
61- verbose (bool/int): whether to print messages on certain tasks
60+ (i.e MFCC, Chroma and MEL spectrogram ).
61+ classification (bool): whether to use classification or regression, default is True.
62+ balance (bool): whether to balance the dataset ( both training and testing ), default is True.
63+ verbose (bool/int): whether to print messages on certain tasks.
6264 ==========================================================
6365 Model params
64- n_rnn_layers (int): number of RNN layers, default is 2
65- cell (keras.layers.RNN instance): RNN cell used to train the model, default is LSTM
66- rnn_units (int): number of units of `cell`, default is 128
67- n_dense_layers (int): number of Dense layers, default is 2
68- dense_units (int): number of units of the Dense layers, default is 128
66+ n_rnn_layers (int): number of RNN layers, default is 2.
67+ cell (keras.layers.RNN instance): RNN cell used to train the model, default is LSTM.
68+ rnn_units (int): number of units of `cell`, default is 128.
69+ n_dense_layers (int): number of Dense layers, default is 2.
70+ dense_units (int): number of units of the Dense layers, default is 128.
6971 dropout (list/float): dropout rate,
70- - if list, it indicates the dropout rate of each layer
71- - if float, it indicates the dropout rate for all layers
72- default is 0.3
72+ - if list, it indicates the dropout rate of each layer.
73+ - if float, it indicates the dropout rate for all layers.
74+ Default is 0.3.
7375 ==========================================================
7476 Training params
75- batch_size (int): number of samples per gradient update, default is 64
76- epochs (int): number of epochs, default is 1000
77- optimizer (str/keras.optimizers.Optimizer instance): optimizer used to train, default is "adam"
78- loss (str, callback from keras.losses): loss function that is used to minimize during training,
77+ batch_size (int): number of samples per gradient update, default is 64.
78+ epochs (int): number of epochs, default is 1000.
79+ optimizer (str/keras.optimizers.Optimizer instance): optimizer used to train, default is "adam".
80+ loss (str/ callback from keras.losses): loss function that is used to minimize during training,
7981 default is "categorical_crossentropy" for classification and "mean_squared_error" for
80- regression
82+ regression.
8183 """
8284 # init EmotionRecognizer
8385 super ().__init__ (None , ** kwargs )
@@ -117,6 +119,12 @@ def __init__(self, **kwargs):
117119 self .model_created = False
118120
119121 def _update_model_name (self ):
122+ """
123+ Generates a unique model name based on parameters passed and put it on `self.model_name`.
124+ This is used when saving the model.
125+ """
126+ # get first letters of emotions, for instance:
127+ # ["sad", "neutral", "happy"] => 'HNS' (sorted alphabetically)
120128 emotions_str = get_first_letters (self .emotions )
121129 # 'c' for classification & 'r' for regression
122130 problem_type = 'c' if self .classification else 'r'
@@ -128,15 +136,19 @@ def _get_model_filename(self):
128136 return f"results/{ self .model_name } "
129137
130138 def _model_exists (self ):
131- """Checks if model already exists in disk, returns the filename,
132- returns `None` otherwise"""
139+ """
140+ Checks if model already exists in disk, returns the filename,
141+ and returns `None` otherwise.
142+ """
133143 filename = self ._get_model_filename ()
134144 return filename if os .path .isfile (filename ) else None
135145
136146 def _compute_input_length (self ):
147+ """
148+ Calculates the input shape to be able to construct the model.
149+ """
137150 if not self .data_loaded :
138151 self .load_data ()
139-
140152 self .input_length = self .X_train [0 ].shape [1 ]
141153
142154 def _verify_emotions (self ):
@@ -146,9 +158,8 @@ def _verify_emotions(self):
146158
147159 def create_model (self ):
148160 """
149- Constructs the neural network
161+ Constructs the neural network based on parameters passed.
150162 """
151-
152163 if self .model_created :
153164 # model already created, why call twice
154165 return
@@ -196,17 +207,23 @@ def create_model(self):
196207 print ("[+] Model created" )
197208
198209 def load_data (self ):
210+ """
211+ Loads and extracts features from the audio files for the db's specified.
212+ And then reshapes the data.
213+ """
199214 super ().load_data ()
200- # reshape to 3 dims
215+ # reshape X's to 3 dims
201216 X_train_shape = self .X_train .shape
202217 X_test_shape = self .X_test .shape
203218 self .X_train = self .X_train .reshape ((1 , X_train_shape [0 ], X_train_shape [1 ]))
204219 self .X_test = self .X_test .reshape ((1 , X_test_shape [0 ], X_test_shape [1 ]))
205220
206221 if self .classification :
222+ # one-hot encode when its classification
207223 self .y_train = to_categorical ([ self .emotions2int [str (e )] for e in self .y_train ])
208224 self .y_test = to_categorical ([ self .emotions2int [str (e )] for e in self .y_test ])
209225
226+ # reshape labels
210227 y_train_shape = self .y_train .shape
211228 y_test_shape = self .y_test .shape
212229 if self .classification :
@@ -217,7 +234,12 @@ def load_data(self):
217234 self .y_test = self .y_test .reshape ((1 , y_test_shape [0 ], 1 ))
218235
219236 def train (self , override = False ):
220-
237+ """
238+ Trains the neural network.
239+ Params:
240+ override (bool): whether to override the previous identical model, can be used
241+ when you changed the dataset, default is False
242+ """
221243 # if model isn't created yet, create it
222244 if not self .model_created :
223245 self .create_model ()
@@ -262,6 +284,19 @@ def predict(self, audio_path):
262284 else :
263285 return self .model .predict (feature )[0 ][0 ][0 ]
264286
287+ def predict_proba (self , audio_path ):
288+ if self .classification :
289+ feature = extract_feature (audio_path , ** self .audio_config ).reshape ((1 , 1 , self .input_length ))
290+ proba = self .model .predict (feature )[0 ][0 ]
291+ result = {}
292+ for prob , emotion in zip (proba , self .emotions ):
293+ result [emotion ] = prob
294+ return result
295+ else :
296+ raise NotImplementedError ("Probability prediction doesn't make sense for regression" )
297+
298+
299+
265300 def test_score (self ):
266301 y_test = self .y_test [0 ]
267302 if self .classification :
0 commit comments