Skip to content

Commit 468d2f2

Browse files
committed
Update notebook too
1 parent 321f571 commit 468d2f2

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

examples/camvid_segmentation_multiclass.ipynb

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@
492492
" augmentation=get_validation_augmentation(),\n",
493493
")\n",
494494
"\n",
495-
"#Change to > 0 if not on Windows machine\n",
495+
"# Change to > 0 if not on Windows machine\n",
496496
"train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)\n",
497497
"valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=0)\n",
498498
"test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)"
@@ -545,12 +545,10 @@
545545
"import pytorch_lightning as pl\n",
546546
"import segmentation_models_pytorch as smp\n",
547547
"import torch\n",
548-
"import torch.nn.functional as F\n",
549548
"from torch.optim import lr_scheduler\n",
550549
"\n",
551550
"\n",
552551
"class CamVidModel(pl.LightningModule):\n",
553-
"\n",
554552
" def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):\n",
555553
" super().__init__()\n",
556554
" self.model = smp.create_model(\n",
@@ -591,13 +589,14 @@
591589
" mask = mask.long()\n",
592590
"\n",
593591
" # Mask shape\n",
594-
" assert mask.ndim == 3 # [batch_size, H, W]\n",
592+
" assert mask.ndim == 3 # [batch_size, H, W]\n",
595593
"\n",
596594
" # Predict mask logits\n",
597595
" logits_mask = self.forward(image)\n",
598-
" \n",
599-
" assert logits_mask.shape[1] == self.number_of_classes # [batch_size, number_of_classes, H, W]\n",
600-
" \n",
596+
"\n",
597+
" assert (\n",
598+
" logits_mask.shape[1] == self.number_of_classes\n",
599+
" ) # [batch_size, number_of_classes, H, W]\n",
601600
"\n",
602601
" # Ensure the logits mask is contiguous\n",
603602
" logits_mask = logits_mask.contiguous()\n",
@@ -1678,7 +1677,6 @@
16781677
}
16791678
],
16801679
"source": [
1681-
"import matplotlib.pyplot as plt\n",
16821680
"import numpy as np\n",
16831681
"\n",
16841682
"# Fetch a batch from the test loader\n",

0 commit comments

Comments
 (0)