|
253 | 253 | "model_log = model.fit(\n", |
254 | 254 | " X_train, Y_train,\n", |
255 | 255 | " batch_size=batch_size,\n", |
256 | | - " shuffle=False,\n", |
| 256 | + " shuffle=True,\n", |
257 | 257 | " epochs=num_epochs,\n", |
258 | 258 | " validation_data=(X_test, Y_test),\n", |
259 | 259 | " verbose=verbose)" |
|
390 | 390 | "outputs": [], |
391 | 391 | "source": [ |
392 | 392 | "if N == 2: # 2D plot of data and classification (curved) line\n", |
| 393 | + " levels = [0.0, 0.05, 0.1, 0.37, 0.5, 0.63, 0.9, 0.95, 1]\n", |
| 394 | + "\n", |
393 | 395 | " f1, f2 = np.arange(-6, 6, 0.05), np.arange(-6, 6, 0.05)\n", |
394 | 396 | " xv, yv = np.meshgrid(f1, f2)\n", |
395 | 397 | " # create data such that TF can handle it in model.predict():\n", |
|
398 | 400 | "\n", |
399 | 401 | " ygrid = model.predict(Xgrid, verbose=0) # probability 0...1\n", |
400 | 402 | "\n", |
| 403 | + " # hard decision boundary:\n", |
401 | 404 | " # ygrid = predict_class(ygrid) # binary classes {0,1}\n", |
402 | 405 | "\n", |
403 | 406 | " # reshape to plane\n", |
404 | 407 | " ygrid = np.reshape(ygrid, (xv.shape[0], xv.shape[1]))\n", |
405 | 408 | "\n", |
406 | 409 | " plt.figure(figsize=(12, 5))\n", |
407 | | - "\n", |
408 | 410 | " plt.subplot(1, 2, 1) # left plot for training data set\n", |
409 | 411 | " plt.plot(X_train[Y_train[:, 0] == 0, 0],\n", |
410 | | - " X_train[Y_train[:, 0] == 0, 1], \"C0o\", ms=1)\n", |
| 412 | + " X_train[Y_train[:, 0] == 0, 1],\n", |
| 413 | + " \"o\", color='dodgerblue', ms=1)\n", |
411 | 414 | " plt.plot(X_train[Y_train[:, 0] == 1, 0],\n", |
412 | | - " X_train[Y_train[:, 0] == 1, 1], \"C1o\", ms=1)\n", |
413 | | - " plt.contourf(f1, f2, ygrid, cmap=\"RdBu_r\")\n", |
| 415 | + " X_train[Y_train[:, 0] == 1, 1],\n", |
| 416 | + " \"o\", color='orangered', ms=1)\n", |
| 417 | + " plt.contourf(f1, f2, ygrid, cmap=\"RdBu_r\", levels=levels)\n", |
414 | 418 | " plt.colorbar()\n", |
415 | 419 | " plt.axis(\"square\")\n", |
416 | 420 | " plt.xlim(-6, 6)\n", |
|
421 | 425 | "\n", |
422 | 426 | " plt.subplot(1, 2, 2) # right plot for test data set\n", |
423 | 427 | " plt.plot(X_test[Y_test[:, 0] == 0, 0],\n", |
424 | | - " X_test[Y_test[:, 0] == 0, 1], \"C0o\", ms=1)\n", |
| 428 | + " X_test[Y_test[:, 0] == 0, 1],\n", |
| 429 | + " \"o\", color='dodgerblue', ms=1)\n", |
425 | 430 | " plt.plot(X_test[Y_test[:, 0] == 1, 0],\n", |
426 | | - " X_test[Y_test[:, 0] == 1, 1], \"C1o\", ms=1)\n", |
427 | | - " plt.contourf(f1, f2, ygrid, cmap=\"RdBu_r\")\n", |
| 431 | + " X_test[Y_test[:, 0] == 1, 1],\n", |
| 432 | + " \"o\", color='orangered', ms=1)\n", |
| 433 | + " plt.contourf(f1, f2, ygrid, cmap=\"RdBu_r\", levels=levels)\n", |
428 | 434 | " plt.colorbar()\n", |
429 | 435 | " plt.axis(\"square\")\n", |
430 | 436 | " plt.xlim(-6, 6)\n", |
|
0 commit comments