|
81 | 81 | "\n", |
82 | 82 | "# these seeds produce 'nice' two classes each with\n", |
83 | 83 | "# two clusters for chosen M, N and train_size\n", |
84 | | - "random_state_idx = 0\n", |
| 84 | + "random_state_idx = 5\n", |
85 | 85 | "random_state = np.array([7, 21, 24, 25, 29, 33, 38])\n", |
86 | 86 | "X, Y = make_classification(\n", |
87 | 87 | " n_samples=M,\n", |
|
97 | 97 | "X_train, X_test, Y_train, Y_test = train_test_split(\n", |
98 | 98 | " X, Y, train_size=train_size, random_state=None\n", |
99 | 99 | ")\n", |
| 100 | + "Y_train = Y_train[:, None]\n", |
| 101 | + "Y_test = Y_test[:, None]\n", |
100 | 102 | "M_train = X_train.shape[0]\n", |
101 | 103 | "M_test = X_test.shape[0]\n", |
102 | 104 | "print(\"M_train\", M_train)\n", |
103 | | - "print(\"X train dim\", X_train.shape, \"Y train dim\", Y_train.shape)\n", |
| 105 | + "print(\"X train dim\", X_train.shape, \", Y train dim\", Y_train.shape)\n", |
104 | 106 | "print(\"M_test\", M_test)\n", |
105 | | - "print(\"X test dim\", X_test.shape, \"Y test dim\", Y_test.shape)" |
| 107 | + "print(\"X test dim\", X_test.shape, \", Y test dim\", Y_test.shape)" |
106 | 108 | ] |
107 | 109 | }, |
108 | 110 | { |
|
202 | 204 | "metadata": {}, |
203 | 205 | "outputs": [], |
204 | 206 | "source": [ |
205 | | - "Y_train = Y_train[:, None]\n", |
206 | | - "Y_test = Y_test[:, None]\n", |
207 | | - "\n", |
208 | | - "print(\"X train dim\", X_train.shape, \"Y train dim\", Y_train.shape)\n", |
209 | | - "print(\"X test dim\", X_test.shape, \"Y test dim\", Y_test.shape)\n", |
210 | | - "\n", |
211 | 207 | "data_train = TensorDataset(torch.FloatTensor(X_train),\n", |
212 | 208 | " torch.FloatTensor(Y_train))\n", |
| 209 | + "data_test = TensorDataset(torch.FloatTensor(X_test),\n", |
| 210 | + " torch.FloatTensor(Y_test))\n", |
213 | 211 | "data_train_loader = DataLoader(dataset=data_train,\n", |
214 | 212 | " batch_size=batch_size,\n", |
215 | 213 | " shuffle=True)" |
|
235 | 233 | " def __init__(self, input_size):\n", |
236 | 234 | " super(Model, self).__init__()\n", |
237 | 235 | "\n", |
238 | | - " self.linear1 = torch.nn.Linear(input_size, 2)\n", |
| 236 | + " self.linear1 = torch.nn.Linear(input_size, 3)\n", |
239 | 237 | " self.act1 = torch.nn.Tanh()\n", |
240 | | - " self.linear2 = torch.nn.Linear(2, 2)\n", |
| 238 | + " self.linear2 = torch.nn.Linear(3, 2)\n", |
241 | 239 | " self.act2 = torch.nn.Tanh()\n", |
242 | 240 | " self.linear3 = torch.nn.Linear(2, 1)\n", |
243 | 241 | " self.sigmoid = torch.nn.Sigmoid()\n", |
|
335 | 333 | " loss.backward() # back prop\n", |
336 | 334 | " optimizer.step() # gradient descent\n", |
337 | 335 | " optimizer.zero_grad() # reset gradients for next iter\n", |
338 | | - " # all batches per epoch, now do a prediction\n", |
339 | | - " # on train & test so check where we are:\n", |
 | 336 | + "        # all batches of this epoch are done;\n", |
 | 337 | + "        # next, run a prediction on train & test\n", |
 | 338 | + "        # to check where we are:\n", |
340 | 339 | "        with torch.no_grad():  # no_grad so evaluation does not affect backprop\n", |
341 | | - " er = empirical_risk(model(\n", |
342 | | - " torch.tensor(X_train, dtype=torch.float32).to(device)),\n", |
343 | | - " torch.tensor(Y_train, dtype=torch.float32).to(device))\n", |
344 | | - " print('train loss', er)\n", |
345 | | - " er = empirical_risk(model(\n", |
346 | | - " torch.tensor(X_test, dtype=torch.float32).to(device)),\n", |
347 | | - " torch.tensor(Y_test, dtype=torch.float32).to(device))\n", |
348 | | - " print('test loss', er)\n", |
 | 340 | + "            er_train = empirical_risk(\n", |
 | 341 | + "                model(data_train[:][0].to(device)),  # model prediction\n", |
 | 342 | + "                data_train[:][1].to(device))  # ground truth\n", |
 | 343 | + "            print('train loss', '%0.15f' % er_train)\n", |
 | 344 | + "            er_test = empirical_risk(\n", |
 | 345 | + "                model(data_test[:][0].to(device)),  # model prediction\n", |
 | 346 | + "                data_test[:][1].to(device))  # ground truth\n", |
 | 347 | + "            print('test loss', '%0.15f' % er_test)\n", |
349 | 348 | " print('#####\\n')" |
350 | 349 | ] |
351 | 350 | }, |
|
372 | 371 | "metadata": {}, |
373 | 372 | "outputs": [], |
374 | 373 | "source": [ |
375 | | - "with torch.no_grad():\n", |
376 | | - "\n", |
377 | | - " er = empirical_risk(model.forward(\n", |
378 | | - " torch.tensor(X_train, dtype=torch.float32).to(device)),\n", |
379 | | - " torch.tensor(Y_train, dtype=torch.float32).to(device))\n", |
380 | | - " print('final train loss', er)\n", |
381 | | - "\n", |
382 | | - " er = empirical_risk(model.forward(\n", |
383 | | - " torch.tensor(X_test, dtype=torch.float32).to(device)),\n", |
384 | | - " torch.tensor(Y_test, dtype=torch.float32).to(device))\n", |
385 | | - " print('final test loss', er)" |
 | 374 | + "with torch.no_grad():  # no_grad so evaluation does not affect backprop\n", |
 | 375 | + "    er_train = empirical_risk(\n", |
 | 376 | + "        model(data_train[:][0].to(device)),  # model prediction\n", |
 | 377 | + "        data_train[:][1].to(device))  # ground truth\n", |
 | 378 | + "    print('final train loss', '%0.15f' % er_train)\n", |
 | 379 | + "    er_test = empirical_risk(\n", |
 | 380 | + "        model(data_test[:][0].to(device)),  # model prediction\n", |
 | 381 | + "        data_test[:][1].to(device))  # ground truth\n", |
 | 382 | + "    print('final test loss', '%0.15f' % er_test)" |
386 | 383 | ] |
387 | 384 | }, |
388 | 385 | { |
|
400 | 397 | "metadata": {}, |
401 | 398 | "outputs": [], |
402 | 399 | "source": [ |
 | 400 | + "# from here on, we work with 2D numpy arrays of shape (80000, 1) and (20000, 1)\n", |
403 | 401 | "with torch.no_grad():\n", |
404 | | - "\n", |
405 | | - " X_tmp = torch.tensor(X_train, dtype=torch.float32).to(device)\n", |
406 | | - " Y_pred_train = model.predict_class(X_tmp).cpu()\n", |
407 | | - "\n", |
408 | | - " X_tmp = torch.tensor(X_test, dtype=torch.float32).to(device)\n", |
409 | | - " Y_pred_test = model.predict_class(X_tmp).cpu()" |
 | 402 | + "    Y_pred_train = model.predict_class(data_train[:][0].to(device)).cpu().numpy()\n", |
 | 403 | + "    Y_pred_test = model.predict_class(data_test[:][0].to(device)).cpu().numpy()\n", |
| 404 | + "Y_train.shape, Y_pred_train.shape, Y_test.shape, Y_pred_test.shape" |
410 | 405 | ] |
411 | 406 | }, |
412 | 407 | { |
|
564 | 559 | " plt.ylabel(\"feature 2\")" |
565 | 560 | ] |
566 | 561 | }, |
| 562 | + { |
| 563 | + "cell_type": "markdown", |
| 564 | + "id": "82be2710", |
| 565 | + "metadata": {}, |
| 566 | + "source": [ |
| 567 | + "## Plot Decision Curves" |
| 568 | + ] |
| 569 | + }, |
| 570 | + { |
| 571 | + "cell_type": "code", |
| 572 | + "execution_count": null, |
| 573 | + "id": "c2b8382b", |
| 574 | + "metadata": {}, |
| 575 | + "outputs": [], |
| 576 | + "source": [ |
| 577 | + "plt.figure(figsize=(10, 10))\n", |
| 578 | + "N = 500\n", |
| 579 | + "predict_threshold = np.linspace(0, 1, N, endpoint=False)\n", |
| 580 | + "TPR = np.zeros_like(predict_threshold)\n", |
| 581 | + "FPR = np.zeros_like(predict_threshold)\n", |
| 582 | + "TNR = np.zeros_like(predict_threshold)\n", |
| 583 | + "FNR = np.zeros_like(predict_threshold)\n", |
| 584 | + "with torch.no_grad():\n", |
 | 585 | + "    Y_pred_test = model(data_test[:][0].to(device)).cpu().numpy()\n", |
| 586 | + "\n", |
| 587 | + "for idx, val in enumerate(predict_threshold):\n", |
| 588 | + " Y_pred_threshold = (Y_pred_test >= val) * 1\n", |
| 589 | + " cm = confusion_matrix(Y_test, Y_pred_threshold)\n", |
| 590 | + " TN, FP = cm[0, 0], cm[0, 1]\n", |
| 591 | + " FN, TP = cm[1, 0], cm[1, 1]\n", |
| 592 | + " FPR[idx] = FP / (TN+FP) # type I error\n", |
| 593 | + " TPR[idx] = TP / (FN+TP) # recall, sensitivity, test power\n", |
| 594 | + " FNR[idx] = FN / (FN+TP) # type II error\n", |
| 595 | + " TNR[idx] = TN / (TN+FP) # specificity, selectivity\n", |
| 596 | + " if idx == N//2: # indicate 0.5 probability decision point\n", |
| 597 | + " plt.subplot(2, 2, 1)\n", |
| 598 | + " plt.text(FPR[idx], TPR[idx], '. %0.2f' % val)\n", |
| 599 | + " plt.subplot(2, 2, 2)\n", |
| 600 | + " plt.text(FPR[idx], FNR[idx], '. %0.2f' % val)\n", |
| 601 | + " plt.subplot(2, 2, 3)\n", |
| 602 | + " plt.text(FNR[idx], TNR[idx], '. %0.2f' % val)\n", |
| 603 | + " plt.subplot(2, 2, 4)\n", |
| 604 | + " plt.text(TPR[idx], TNR[idx], '. %0.2f' % val)\n", |
| 605 | + "\n", |
| 606 | + "# receiver operating characteristic (ROC) curve:\n", |
| 607 | + "plt.subplot(2, 2, 1)\n", |
| 608 | + "plt.plot(FPR, TPR, lw=2)\n", |
 | 609 | + "plt.plot(0.01, 0.99, 'C3x', label='1%, 99% target')\n", |
| 610 | + "plt.plot([0, 1], [0, 1])\n", |
| 611 | + "plt.xlabel('FPR = type I error')\n", |
| 612 | + "plt.ylabel('TPR = recall = sensitivity = power')\n", |
| 613 | + "plt.grid(True)\n", |
| 614 | + "plt.legend()\n", |
| 615 | + "plt.axis([0, 0.1, 0.9, 1])\n", |
| 616 | + "\n", |
| 617 | + "plt.subplot(2, 2, 2)\n", |
| 618 | + "plt.plot(FPR, FNR, lw=2)\n", |
| 619 | + "plt.plot(0.01, 0.01, 'C3x')\n", |
| 620 | + "plt.plot([0, 1], [1, 0])\n", |
| 621 | + "plt.xlabel('FPR = type I error')\n", |
| 622 | + "plt.ylabel('FNR = type II error')\n", |
| 623 | + "plt.grid(True)\n", |
| 624 | + "plt.axis([0, 0.1, 0, 0.1])\n", |
| 625 | + "\n", |
| 626 | + "plt.subplot(2, 2, 3)\n", |
| 627 | + "plt.plot(FNR, TNR, lw=2)\n", |
| 628 | + "plt.plot(0.01, 0.99, 'C3x')\n", |
| 629 | + "plt.plot([0, 1], [0, 1])\n", |
| 630 | + "plt.xlabel('FNR = type II error')\n", |
| 631 | + "plt.ylabel('TNR = specificity = selectivity')\n", |
| 632 | + "plt.grid(True)\n", |
| 633 | + "plt.axis([0, 0.1, 0.9, 1])\n", |
| 634 | + "\n", |
| 635 | + "plt.subplot(2, 2, 4)\n", |
| 636 | + "plt.plot(TPR, TNR, lw=2)\n", |
| 637 | + "plt.plot(0.99, 0.99, 'C3x')\n", |
| 638 | + "plt.plot([0, 1], [1, 0])\n", |
| 639 | + "plt.xlabel('TPR = recall = sensitivity = power')\n", |
| 640 | + "plt.ylabel('TNR = specificity = selectivity')\n", |
| 641 | + "plt.grid(True)\n", |
| 642 | + "plt.axis([0.9, 1, 0.9, 1])" |
| 643 | + ] |
| 644 | + }, |
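 | | + { |
 | | + "cell_type": "markdown", |
 | | + "id": "d4f3aa01", |
 | | + "metadata": {}, |
 | | + "source": [ |
 | | + "As a quick sanity check, the ROC curve can be condensed into a single number: the sketch below approximates the area under the curve (AUC) from the `FPR`/`TPR` arrays computed above using the trapezoidal rule." |
 | | + ] |
 | | + }, |
 | | + { |
 | | + "cell_type": "code", |
 | | + "execution_count": null, |
 | | + "id": "d4f3aa02", |
 | | + "metadata": {}, |
 | | + "outputs": [], |
 | | + "source": [ |
 | | + "# approximate the ROC AUC with the trapezoidal rule;\n", |
 | | + "# thresholds run from 0 to 1, so FPR is (roughly) descending,\n", |
 | | + "# hence reverse both arrays to integrate over ascending FPR\n", |
 | | + "AUC = np.trapz(TPR[::-1], FPR[::-1])\n", |
 | | + "print('approx ROC AUC', '%0.6f' % AUC)" |
 | | + ] |
 | | + }, |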
| 645 | + { |
| 646 | + "cell_type": "markdown", |
| 647 | + "id": "69e3213d", |
| 648 | + "metadata": {}, |
| 649 | + "source": [ |
 | 650 | + "Using the typical 50% decision boundary, the trained model performs almost ideally on the unseen test data, coming close to the 1% type I error / 99% power target." |
| 651 | + ] |
| 652 | + }, |
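 | | + { |
 | | + "cell_type": "markdown", |
 | | + "id": "d4f3aa03", |
 | | + "metadata": {}, |
 | | + "source": [ |
 | | + "A minimal numerical check of this claim: the sketch below reuses the raw `Y_pred_test` probabilities from the plotting cell and prints the type I error and the power at the standard 0.5 threshold." |
 | | + ] |
 | | + }, |
 | | + { |
 | | + "cell_type": "code", |
 | | + "execution_count": null, |
 | | + "id": "d4f3aa04", |
 | | + "metadata": {}, |
 | | + "outputs": [], |
 | | + "source": [ |
 | | + "# confusion-matrix rates at the standard 0.5 decision threshold\n", |
 | | + "cm = confusion_matrix(Y_test, (Y_pred_test >= 0.5) * 1)\n", |
 | | + "TN, FP = cm[0, 0], cm[0, 1]\n", |
 | | + "FN, TP = cm[1, 0], cm[1, 1]\n", |
 | | + "print('type I error (FPR)', '%0.4f' % (FP / (TN+FP)))\n", |
 | | + "print('power (TPR)', '%0.4f' % (TP / (FN+TP)))" |
 | | + ] |
 | | + }, |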
567 | 653 | { |
568 | 654 | "cell_type": "markdown", |
569 | 655 | "id": "43218100", |