2424 linear_regression ,
2525 pme ,
2626 tf_micro ,
27- nnom
27+ nnom ,
2828)
2929from django .core .exceptions import ValidationError
3030
3535 "bonsai" ,
3636 "pme" ,
3737 "linear_regression" ,
38- "nnom"
38+ "nnom" ,
3939]
4040
4141CLASSIFER_MAP = {
4242 "decision tree ensemble" : "decision_tree_ensemble" ,
4343 "tensorflow lite for microcontrollers" : "tf_micro" ,
44+ "Neural Network" : "tf_micro" ,
4445 "nnom" : "nnom" ,
4546 "pme" : "pme" ,
4647 "boosted tree ensemble" : "boosted_tree_ensemble" ,
@@ -61,7 +62,7 @@ def get_classifier_type(model_configuration):
6162 return classifier_type .lower ()
6263
6364
64- #TODO: Make this an interface that returns the object instead of having all of these if statements
65+ # TODO: Make this an interface that returns the object instead of having all of these if statements
6566class ModelGen :
6667 @staticmethod
6768 def create_classifier_structures (classifier_type , kb_models ):
@@ -82,10 +83,10 @@ def create_classifier_structures(classifier_type, kb_models):
8283
8384 if classifier_type == "linear_regression" :
8485 return linear_regression .create_classifier_structures (kb_models )
85-
86+
8687 if classifier_type == "nnom" :
8788 return nnom .create_classifier_structures (kb_models )
88-
89+
8990 return ""
9091
9192 @staticmethod
@@ -107,7 +108,7 @@ def create_max_tmp_parameters(classifier_type, kb_models):
107108
108109 if classifier_type == "linear_regression" :
109110 return linear_regression .create_max_tmp_parameters (kb_models )
110-
111+
111112 if classifier_type == "nnom" :
112113 return nnom .create_max_tmp_parameters (kb_models )
113114
@@ -151,7 +152,7 @@ def validate_model_parameters(model_parameters, model_configuration):
151152
152153 if classifier_type == "linear_regression" :
153154 return linear_regression .validate_model_parameters (model_parameters )
154-
155+
155156 if classifier_type == "nnom" :
156157 return nnom .validate_model_parameters (model_parameters )
157158
@@ -180,8 +181,7 @@ def validate_model_configuration(model_configuration):
180181
181182 if classifier_type == "linear_regression" :
182183 return linear_regression .validate_model_configuration (model_configuration )
183-
184-
184+
185185 if classifier_type == "nnom" :
186186 return nnom .validate_model_configuration (model_configuration )
187187
@@ -204,7 +204,7 @@ def get_output_tensor_size(classifier_type, model):
204204
205205 if classifier_type == "linear_regression" :
206206 return linear_regression .get_output_tensor_size (model )
207-
207+
208208 if classifier_type == "nnom" :
209209 return nnom .get_output_tensor_size (model )
210210
@@ -232,7 +232,7 @@ def get_input_feature_type(model):
232232
233233 if classifier_type == "linear_regression" :
234234 return FLOAT
235-
235+
236236 if classifier_type == "nnom" :
237237 return UINT8_T
238238
@@ -263,7 +263,7 @@ def get_input_feature_def(model):
263263
264264 if classifier_type == "nnom" :
265265 return UINT8_T
266-
266+
267267 raise ValueError ("No classifier type found" )
268268
269269 @staticmethod
@@ -273,7 +273,7 @@ def get_model_type(model):
273273 CLASSIFICATION = 1
274274 if classifier_type == "tf_micro" :
275275 return CLASSIFICATION
276-
276+
277277 if classifier_type == "nnom" :
278278 return CLASSIFICATION
279279
0 commit comments