|
45 | 45 | "from sklearn.datasets import make_classification\n", |
46 | 46 | "from sklearn.model_selection import train_test_split\n", |
47 | 47 | "import tensorflow as tf\n", |
48 | | - "import tensorflow.keras as keras\n", |
| 48 | + "from tensorflow import keras\n", |
49 | 49 | "\n", |
50 | 50 | "\n", |
51 | 51 | "print(\n", |
|
98 | 98 | { |
99 | 99 | "cell_type": "code", |
100 | 100 | "execution_count": null, |
101 | | - "id": "e99a940f", |
| 101 | + "id": "c795c19c-aa94-48b9-9385-390fd831a1fa", |
102 | 102 | "metadata": {}, |
103 | 103 | "outputs": [], |
104 | 104 | "source": [ |
105 | | - "def predict_class(y):\n", |
106 | | - " y[y < 0.5], y[y >= 0.5] = 0, 1" |
| 105 | + "def predict_class_tf(y):\n", |
| 106 | + " y[y[:, 0] < 0.5, :], y[y[:, 0] >= 0.5, :] = 0, 1" |
| 107 | + ] |
| 108 | + }, |
| 109 | + { |
| 110 | + "cell_type": "code", |
| 111 | + "execution_count": null, |
| 112 | + "id": "8388f542-910c-408f-9df7-b9ef4ed6f02a", |
| 113 | + "metadata": {}, |
| 114 | + "outputs": [], |
| 115 | + "source": [ |
| 116 | + "def predict_class_my(y):\n", |
| 117 | + " y[:, y[0, :] < 0.5], y[:, y[0, :] >= 0.5] = 0, 1" |
107 | 118 | ] |
108 | 119 | }, |
109 | 120 | { |
|
116 | 127 | "def evaluate(y_true, y_pred):\n", |
117 | 128 | " y_true_tmp = np.copy(y_true)\n", |
118 | 129 | " y_pred_tmp = np.copy(y_pred)\n", |
119 | | - " predict_class(y_pred_tmp)\n", |
| 130 | + " predict_class_my(y_pred_tmp)\n", |
120 | 131 | "\n", |
121 | 132 | " # https://www.tensorflow.org/api_docs/python/tf/math/confusion_matrix\n", |
122 | 133 | " # The matrix columns represent the prediction labels.\n", |
|
193 | 204 | ")\n", |
194 | 205 | "X_train, X_test, Y_train, Y_test = train_test_split(\n", |
195 | 206 | " X, Y, train_size=train_size, random_state=None\n", |
196 | | - ")" |
| 207 | + ")\n" |
197 | 208 | ] |
198 | 209 | }, |
199 | 210 | { |
|
221 | 232 | "print(\"X train dim\", X_train_our.shape, \"Y train dim\", Y_train_our.shape)" |
222 | 233 | ] |
223 | 234 | }, |
| 235 | + { |
| 236 | + "cell_type": "markdown", |
| 237 | + "id": "a3aedc53-87b0-4e3c-91a7-73be1df13b95", |
| 238 | + "metadata": {}, |
| 239 | + "source": [ |
| 240 | + "- prepare the label arrays for TensorFlow" |
| 241 | + ] |
| 242 | + }, |
| 243 | + { |
| 244 | + "cell_type": "code", |
| 245 | + "execution_count": null, |
| 246 | + "id": "3d4f778d-786a-4146-a662-4bbafa8577e0", |
| 247 | + "metadata": {}, |
| 248 | + "outputs": [], |
| 249 | + "source": [ |
| 250 | + "Y_train = Y_train[:, None] # newer TF needs label arrays of shape (x, 1) instead of (x,)\n", |
| 251 | + "Y_test = Y_test[:, None]\n", |
| 252 | + "X_train.shape, X_test.shape, Y_train.shape, Y_test.shape" |
| 253 | + ] |
| 254 | + }, |
224 | 255 | { |
225 | 256 | "cell_type": "markdown", |
226 | 257 | "id": "f2b9c6e6", |
|
439 | 470 | "source": [ |
440 | 471 | "# prediction after training finished\n", |
441 | 472 | "Y_train_pred_tf = model.predict(X_train)\n", |
442 | | - "predict_class(Y_train_pred_tf)\n", |
| 473 | + "predict_class_tf(Y_train_pred_tf)\n", |
| 474 | + "\n", |
| 475 | + "print(Y_train_pred_tf.shape, Y_train.shape)\n", |
443 | 476 | "\n", |
444 | 477 | "# confusion matrix\n", |
445 | 478 | "cm_train_tf = tf.math.confusion_matrix(\n", |
446 | | - " labels=Y_train, predictions=Y_train_pred_tf, num_classes=2\n", |
| 479 | + " labels=np.squeeze(Y_train), predictions=np.squeeze(Y_train_pred_tf), num_classes=2\n", |
447 | 480 | ")\n", |
448 | 481 | "\n", |
| 482 | + "\n", |
| 483 | + "\n", |
449 | 484 | "# get technical measures for the trained model on the training data set\n", |
450 | 485 | "results_train_tf = model.evaluate(\n", |
451 | 486 | " X_train, Y_train, batch_size=M_train, verbose=verbose\n", |
|
552 | 587 | "print(\"\\nm_test\", M_test)\n", |
553 | 588 | "# our implementation needs transposed data\n", |
554 | 589 | "X_test_our = X_test.T\n", |
555 | | - "Y_test_our = Y_test[None, :]\n", |
| 590 | + "Y_test_our = Y_test.T\n", |
556 | 591 | "print(\"X test dim\", X_test_our.shape, \"Y test dim\", Y_test_our.shape)" |
557 | 592 | ] |
558 | 593 | }, |
|
601 | 636 | "source": [ |
602 | 637 | "# prediction\n", |
603 | 638 | "Y_test_pred_tf = model.predict(X_test)\n", |
604 | | - "predict_class(Y_test_pred_tf)\n", |
| 639 | + "predict_class_tf(Y_test_pred_tf)\n", |
605 | 640 | "\n", |
606 | 641 | "# confusion matrix\n", |
607 | 642 | "cm_test_tf = tf.math.confusion_matrix(\n", |
608 | | - " labels=Y_test, predictions=Y_test_pred_tf, num_classes=2\n", |
| 643 | + " labels=np.squeeze(Y_test), predictions=np.squeeze(Y_test_pred_tf), num_classes=2\n", |
609 | 644 | ")\n", |
610 | 645 | "\n", |
611 | 646 | "# get technical measures for the trained model on the test data set\n", |
|
668 | 703 | "print(\"TF confusion matrix in %\\n\", cm_test_tf / M_test * 100.0)" |
669 | 704 | ] |
670 | 705 | }, |
| 706 | + { |
| 707 | + "cell_type": "code", |
| 708 | + "execution_count": null, |
| 709 | + "id": "e37d308d-9be7-4394-a7fd-d08a1feb279b", |
| 710 | + "metadata": {}, |
| 711 | + "outputs": [], |
| 712 | + "source": [ |
| 713 | + "X_train.shape, Y_train.shape, X_test.shape, Y_test.shape" |
| 714 | + ] |
| 715 | + }, |
671 | 716 | { |
672 | 717 | "cell_type": "code", |
673 | 718 | "execution_count": null, |
|
684 | 729 | "\n", |
685 | 730 | " plt.figure(figsize=(10, 10))\n", |
686 | 731 | " plt.subplot(2, 1, 1)\n", |
687 | | - " plt.plot(X_train[Y_train == 0, 0], X_train[Y_train == 0, 1], \"C0o\", ms=1)\n", |
688 | | - " plt.plot(X_train[Y_train == 1, 0], X_train[Y_train == 1, 1], \"C1o\", ms=1)\n", |
| 732 | + " plt.plot(X_train[Y_train[:, 0] == 0, 0], X_train[Y_train[:, 0] == 0, 1], \"C0o\", ms=1)\n", |
| 733 | + " plt.plot(X_train[Y_train[:, 0] == 1, 0], X_train[Y_train[:, 0] == 1, 1], \"C1o\", ms=1)\n", |
689 | 734 | " plt.contourf(f1, f2, tmp, cmap=\"RdBu_r\")\n", |
690 | 735 | " plt.axis(\"equal\")\n", |
691 | 736 | " plt.colorbar()\n", |
|
694 | 739 | " plt.ylabel(\"feature 2\")\n", |
695 | 740 | "\n", |
696 | 741 | " plt.subplot(2, 1, 2)\n", |
697 | | - " plt.plot(X_test[Y_test == 0, 0], X_test[Y_test == 0, 1], \"C0o\", ms=1)\n", |
698 | | - " plt.plot(X_test[Y_test == 1, 0], X_test[Y_test == 1, 1], \"C1o\", ms=1)\n", |
| 742 | + " plt.plot(X_test[Y_test[:, 0] == 0, 0], X_test[Y_test[:, 0] == 0, 1], \"C0o\", ms=1)\n", |
| 743 | + " plt.plot(X_test[Y_test[:, 0] == 1, 0], X_test[Y_test[:, 0] == 1, 1], \"C1o\", ms=1)\n", |
699 | 744 | " plt.contourf(f1, f2, tmp, cmap=\"RdBu_r\")\n", |
700 | 745 | " plt.axis(\"equal\")\n", |
701 | 746 | " plt.colorbar()\n", |
|
735 | 780 | "name": "python", |
736 | 781 | "nbconvert_exporter": "python", |
737 | 782 | "pygments_lexer": "ipython3", |
738 | | - "version": "3.10.6" |
| 783 | + "version": "3.12.3" |
739 | 784 | } |
740 | 785 | }, |
741 | 786 | "nbformat": 4, |
|
0 commit comments