|
492 | 492 | " augmentation=get_validation_augmentation(),\n",
|
493 | 493 | ")\n",
|
494 | 494 | "\n",
|
495 |
| - "#Change to > 0 if not on Windows machine\n", |
| 495 | + "# Change to > 0 if not on Windows machine\n", |
496 | 496 | "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)\n",
|
497 | 497 | "valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=0)\n",
|
498 | 498 | "test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)"
|
|
545 | 545 | "import pytorch_lightning as pl\n",
|
546 | 546 | "import segmentation_models_pytorch as smp\n",
|
547 | 547 | "import torch\n",
|
548 |
| - "import torch.nn.functional as F\n", |
549 | 548 | "from torch.optim import lr_scheduler\n",
|
550 | 549 | "\n",
|
551 | 550 | "\n",
|
552 | 551 | "class CamVidModel(pl.LightningModule):\n",
|
553 |
| - "\n", |
554 | 552 | " def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):\n",
|
555 | 553 | " super().__init__()\n",
|
556 | 554 | " self.model = smp.create_model(\n",
|
|
591 | 589 | " mask = mask.long()\n",
|
592 | 590 | "\n",
|
593 | 591 | " # Mask shape\n",
|
594 |
| - " assert mask.ndim == 3 # [batch_size, H, W]\n", |
| 592 | + " assert mask.ndim == 3 # [batch_size, H, W]\n", |
595 | 593 | "\n",
|
596 | 594 | " # Predict mask logits\n",
|
597 | 595 | " logits_mask = self.forward(image)\n",
|
598 |
| - " \n", |
599 |
| - " assert logits_mask.shape[1] == self.number_of_classes # [batch_size, number_of_classes, H, W]\n", |
600 |
| - " \n", |
| 596 | + "\n", |
| 597 | + " assert (\n", |
| 598 | + " logits_mask.shape[1] == self.number_of_classes\n", |
| 599 | + " ) # [batch_size, number_of_classes, H, W]\n", |
601 | 600 | "\n",
|
602 | 601 | " # Ensure the logits mask is contiguous\n",
|
603 | 602 | " logits_mask = logits_mask.contiguous()\n",
|
|
1678 | 1677 | }
|
1679 | 1678 | ],
|
1680 | 1679 | "source": [
|
1681 |
| - "import matplotlib.pyplot as plt\n", |
1682 | 1680 | "import numpy as np\n",
|
1683 | 1681 | "\n",
|
1684 | 1682 | "# Fetch a batch from the test loader\n",
|
|
0 commit comments