Skip to content

Commit 8e9b663

Browse files
committed
Update binary_logistic_regression_tf_with_hidden_layers.ipynb
1 parent 9593d4f commit 8e9b663

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

binary_logistic_regression_tf_with_hidden_layers.ipynb

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@
253253
"model_log = model.fit(\n",
254254
" X_train, Y_train,\n",
255255
" batch_size=batch_size,\n",
256-
" shuffle=False,\n",
256+
" shuffle=True,\n",
257257
" epochs=num_epochs,\n",
258258
" validation_data=(X_test, Y_test),\n",
259259
" verbose=verbose)"
@@ -390,6 +390,8 @@
390390
"outputs": [],
391391
"source": [
392392
"if N == 2: # 2D plot of data and classification (curved) line\n",
393+
" levels = [0.0, 0.05, 0.1, 0.37, 0.5, 0.63, 0.9, 0.95, 1]\n",
394+
"\n",
393395
" f1, f2 = np.arange(-6, 6, 0.05), np.arange(-6, 6, 0.05)\n",
394396
" xv, yv = np.meshgrid(f1, f2)\n",
395397
" # create data such that TF can handle it in model.predict():\n",
@@ -398,19 +400,21 @@
398400
"\n",
399401
" ygrid = model.predict(Xgrid, verbose=0) # probability 0...1\n",
400402
"\n",
403+
" # hard decision boundary:\n",
401404
" # ygrid = predict_class(ygrid) # binary classes {0,1}\n",
402405
"\n",
403406
" # reshape to plane\n",
404407
" ygrid = np.reshape(ygrid, (xv.shape[0], xv.shape[1]))\n",
405408
"\n",
406409
" plt.figure(figsize=(12, 5))\n",
407-
"\n",
408410
" plt.subplot(1, 2, 1) # left plot for training data set\n",
409411
" plt.plot(X_train[Y_train[:, 0] == 0, 0],\n",
410-
" X_train[Y_train[:, 0] == 0, 1], \"C0o\", ms=1)\n",
412+
" X_train[Y_train[:, 0] == 0, 1],\n",
413+
" \"o\", color='dodgerblue', ms=1)\n",
411414
" plt.plot(X_train[Y_train[:, 0] == 1, 0],\n",
412-
" X_train[Y_train[:, 0] == 1, 1], \"C1o\", ms=1)\n",
413-
" plt.contourf(f1, f2, ygrid, cmap=\"RdBu_r\")\n",
415+
" X_train[Y_train[:, 0] == 1, 1],\n",
416+
" \"o\", color='orangered', ms=1)\n",
417+
" plt.contourf(f1, f2, ygrid, cmap=\"RdBu_r\", levels=levels)\n",
414418
" plt.colorbar()\n",
415419
" plt.axis(\"square\")\n",
416420
" plt.xlim(-6, 6)\n",
@@ -421,10 +425,12 @@
421425
"\n",
422426
" plt.subplot(1, 2, 2) # right plot for test data set\n",
423427
" plt.plot(X_test[Y_test[:, 0] == 0, 0],\n",
424-
" X_test[Y_test[:, 0] == 0, 1], \"C0o\", ms=1)\n",
428+
" X_test[Y_test[:, 0] == 0, 1],\n",
429+
" \"o\", color='dodgerblue', ms=1)\n",
425430
" plt.plot(X_test[Y_test[:, 0] == 1, 0],\n",
426-
" X_test[Y_test[:, 0] == 1, 1], \"C1o\", ms=1)\n",
427-
" plt.contourf(f1, f2, ygrid, cmap=\"RdBu_r\")\n",
431+
" X_test[Y_test[:, 0] == 1, 1],\n",
432+
" \"o\", color='orangered', ms=1)\n",
433+
" plt.contourf(f1, f2, ygrid, cmap=\"RdBu_r\", levels=levels)\n",
428434
" plt.colorbar()\n",
429435
" plt.axis(\"square\")\n",
430436
" plt.xlim(-6, 6)\n",

0 commit comments

Comments
 (0)