|
44 | 44 | "from sklearn.datasets import make_classification\n", |
45 | 45 | "from sklearn.model_selection import train_test_split\n", |
46 | 46 | "import tensorflow as tf\n", |
47 | | - "import tensorflow.keras as keras\n", |
48 | | - "import tensorflow.keras.backend as K\n", |
| 47 | + "from tensorflow import keras\n", |
| 48 | + "\n", |
49 | 49 | "\n", |
50 | 50 | "print(\n", |
51 | 51 | " \"TF version\",\n", |
|
229 | 229 | "\n", |
230 | 230 | "model.compile(optimizer=optimizer, loss=loss, metrics=metrics)\n", |
231 | 231 | "\n", |
232 | | - "tw = np.sum([K.count_params(w) for w in model.trainable_weights])\n", |
233 | | - "print(\"\\ntrainable_weights:\", tw, \"\\n\")" |
| 232 | + "model.summary()" |
234 | 233 | ] |
235 | 234 | }, |
236 | 235 | { |
|
249 | 248 | "outputs": [], |
250 | 249 | "source": [ |
251 | 250 | "model.fit(\n", |
252 | | - " X_train, Y_train, epochs=epochs, batch_size=batch_size, verbose=verbose\n", |
| 251 | + " X_train, Y_train[:, None], epochs=epochs, batch_size=batch_size, verbose=verbose\n", |
253 | 252 | ")" |
254 | 253 | ] |
255 | 254 | }, |
|
279 | 278 | "metadata": {}, |
280 | 279 | "outputs": [], |
281 | 280 | "source": [ |
282 | | - "results = model.evaluate(X_train, Y_train, batch_size=M_train, verbose=False)\n", |
| 281 | + "results = model.evaluate(X_train, Y_train[:, None], batch_size=M_train, verbose=False)\n", |
283 | 282 | "Y_train_pred = model.predict(X_train)\n", |
284 | | - "predict_class(Y_train_pred)" |
| 283 | + "predict_class(Y_train_pred[:, None])" |
285 | 284 | ] |
286 | 285 | }, |
287 | 286 | { |
|
333 | 332 | "metadata": {}, |
334 | 333 | "outputs": [], |
335 | 334 | "source": [ |
336 | | - "results = model.evaluate(X_test, Y_test, batch_size=M_test, verbose=False)\n", |
| 335 | + "results = model.evaluate(X_test, Y_test[:,None], batch_size=M_test, verbose=False)\n", |
337 | 336 | "Y_test_pred = model.predict(X_test)\n", |
338 | 337 | "predict_class(Y_test_pred)" |
339 | 338 | ] |
|
467 | 466 | "name": "python", |
468 | 467 | "nbconvert_exporter": "python", |
469 | 468 | "pygments_lexer": "ipython3", |
470 | | - "version": "3.10.6" |
| 469 | + "version": "3.12.3" |
471 | 470 | } |
472 | 471 | }, |
473 | 472 | "nbformat": 4, |
|
0 commit comments