diff --git a/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/dataset.py b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/dataset.py new file mode 100644 index 000000000..0d7f5f9c4 --- /dev/null +++ b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/dataset.py @@ -0,0 +1,20 @@ +import numpy as np +import glob +import imageio + +def load_oasis_data(image_dir, label_dir): + # Sort image and label files to ensure corresponding order + image_files = sorted(glob.glob(f"{image_dir}/*.nii.png")) + label_files = sorted(glob.glob(f"{label_dir}/*.nii.png")) + + images = [] + labels = [] + for img_file, lbl_file in zip(image_files, label_files): + image = imageio.imread(img_file) + label = imageio.imread(lbl_file) + images.append(image) + labels.append(label) + + images = np.stack(images, axis=0) + labels = np.stack(labels, axis=0) + return images, labels diff --git a/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/loss_plot.png b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/loss_plot.png new file mode 100644 index 000000000..73645a231 Binary files /dev/null and b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/loss_plot.png differ diff --git a/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/modules.py b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/modules.py new file mode 100644 index 000000000..f41ebeef9 --- /dev/null +++ b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/modules.py @@ -0,0 +1,95 @@ +import torch +import torch.nn as nn + +class DoubleConv(nn.Module): + """(Conv => BN => ReLU) * 2""" + def __init__(self, in_channels, out_channels, mid_channels=None, dropout=0.2): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(mid_channels), + nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.double_conv(x) + +class Down(nn.Module): + """Downscaling with maxpool then double conv""" + def __init__(self, in_channels, out_channels, dropout=0.2): + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), + DoubleConv(in_channels, out_channels, dropout=dropout) + ) + + def forward(self, x): + return self.maxpool_conv(x) + +class Up(nn.Module): + """Upscaling then double conv""" + def __init__(self, in_channels, out_channels, bilinear=True, dropout=0.2): + super().__init__() + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2, dropout=dropout) + else: + self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) + self.conv = DoubleConv(in_channels, out_channels, dropout=dropout) + + def forward(self, x1, x2): + x1 = self.up(x1) + # Pad x1 if necessary to match x2's size + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + x1 = nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2]) + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + +class OutConv(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x): + return self.conv(x) + +class ImprovedUNet(nn.Module): + def __init__(self, n_channels, n_classes, bilinear=True, dropout=0.2): + 
super().__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + + self.inc = DoubleConv(n_channels, 64, dropout=dropout) + self.down1 = Down(64, 128, dropout=dropout) + self.down2 = Down(128, 256, dropout=dropout) + self.down3 = Down(256, 512, dropout=dropout) + factor = 2 if bilinear else 1 + self.down4 = Down(512, 1024 // factor, dropout=dropout) + self.up1 = Up(1024, 512 // factor, bilinear, dropout=dropout) + self.up2 = Up(512, 256 // factor, bilinear, dropout=dropout) + self.up3 = Up(256, 128 // factor, bilinear, dropout=dropout) + self.up4 = Up(128, 64, bilinear, dropout=dropout) + self.outc = OutConv(64, n_classes) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + logits = self.outc(x) + return logits diff --git a/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/predict.py b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/predict.py new file mode 100644 index 000000000..dcc5490eb --- /dev/null +++ b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/predict.py @@ -0,0 +1,54 @@ +import torch +import matplotlib.pyplot as plt +from modules import ImprovedUNet +from dataset import load_oasis_data +from torch.utils.data import DataLoader, TensorDataset + +def dice_coefficient(pred, target, epsilon=1e-6): + pred = pred.int() + target = target.int() + intersection = (pred & target).sum(dim=(1,2,3)) + union = pred.sum(dim=(1,2,3)) + target.sum(dim=(1,2,3)) + dice = (2 * intersection + epsilon) / (union + epsilon) + return dice.mean().item() + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +model = ImprovedUNet(n_channels=1, n_classes=1).to(device) +model.load_state_dict(torch.load('unet_epoch15.pth', map_location=device)) +model.eval() + +test_images, test_labels = load_oasis_data( + '/home/groups/comp3710/OASIS/keras_png_slices_test', + '/home/groups/comp3710/OASIS/keras_png_slices_seg_test' +) + +images_tensor = torch.tensor(test_images, dtype=torch.float32).unsqueeze(1)/255.0 +labels_tensor = torch.tensor(test_labels, dtype=torch.float32).unsqueeze(1)/255.0 + +# Use DataLoader for batching +test_ds = TensorDataset(images_tensor, labels_tensor) +test_dl = DataLoader(test_ds, batch_size=8, shuffle=False) +preds = [] + +with torch.no_grad(): + for xb, _ in test_dl: + out = model(xb.to(device)) + out = torch.sigmoid(out) > 0.5 + preds.append(out.cpu()) +preds = torch.cat(preds, dim=0) + +# Compute Dice coefficient +dice = dice_coefficient(preds, labels_tensor) +print(f"Dice coefficient on test set: {dice:.4f}") + +# Visualize 3 sample predictions +for i in range(3): + fig, axs = plt.subplots(1, 3, figsize=(12, 4)) + axs[0].imshow(test_images[i], cmap='gray') + axs[0].set_title('Input') + axs[1].imshow(test_labels[i], cmap='gray') + axs[1].set_title('Ground Truth') + axs[2].imshow(preds[i][0], cmap='gray') + axs[2].set_title('Prediction') + plt.savefig(f'prediction_{i}.png') + plt.close(fig) diff --git a/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/prediction_0.png b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/prediction_0.png new file mode 100644 index 000000000..8ef7323f0 Binary files /dev/null and b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/prediction_0.png differ diff --git a/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/prediction_1.png b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/prediction_1.png new file mode 100644 index 
000000000..ae9969bf5 Binary files /dev/null and b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/prediction_1.png differ diff --git a/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/prediction_2.png b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/prediction_2.png new file mode 100644 index 000000000..e5d44f430 Binary files /dev/null and b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/prediction_2.png differ diff --git a/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/scripts/load_dataset.py b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/scripts/load_dataset.py new file mode 100644 index 000000000..048d28dc3 --- /dev/null +++ b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/scripts/load_dataset.py @@ -0,0 +1,131 @@ +""" +MRI Dataset Loading and Preprocessing Module + +PURPOSE: This file handles the loading and preprocessing of MRI brain scan images for training a VAE. +It creates a custom PyTorch Dataset class that can load PNG images from directories, apply +transforms (resizing, normalization), and create DataLoaders for efficient batch processing. + +WHY IT'S NEEDED: +- Raw MRI images need to be preprocessed (resized to consistent dimensions, normalized) +- PyTorch requires a custom Dataset class to work with our file structure +- DataLoaders enable efficient batch processing during training +- Visualisation helps verify the data is loaded correctly +""" + +import os +from PIL import Image +import torch +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +import matplotlib.pyplot as plt + +# ----------------------------- +# Custom Dataset Class for MRI Images +# ----------------------------- +class MRIDataset(Dataset): + """ + Custom PyTorch Dataset class for loading MRI brain scan images. + + This class inherits from PyTorch's Dataset and implements the required methods + to load PNG images from a directory, apply transforms, and return them as tensors. + """ + + def __init__(self, folder, transform=None): + """ + Initialize the dataset with a folder containing PNG images. + + Args: + folder (str): Path to directory containing PNG images + transform (callable, optional): Transform to apply to each image + """ + self.folder = folder + self.transform = transform + # Get all PNG files from the folder and sort them for consistent ordering + self.images = sorted([os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".png")]) + + # Ensure we have images to work with + if len(self.images) == 0: + raise ValueError(f"No PNG images found in folder: {folder}") + + def __len__(self): + """Return the number of images in the dataset.""" + return len(self.images) + + def __getitem__(self, idx): + """ + Load and return a single image at the given index. + + Args: + idx (int): Index of the image to load + + Returns: + torch.Tensor: Preprocessed image tensor + """ + # Load image and convert to grayscale (single channel) + img = Image.open(self.images[idx]).convert('L') # 'L' mode = grayscale + + # Apply transforms if provided (resize, normalize, etc.) 
+ if self.transform: + img = self.transform(img) + return img + +# ----------------------------- +# Data Path Configuration +# ----------------------------- +# Get the directory where this script is located +script_dir = os.path.dirname(os.path.abspath(__file__)) # folder of this script + +# Navigate to the data directory (two levels up from scripts folder) +data_base = os.path.abspath(os.path.join(script_dir, "../../keras_png_slices_data")) + +# Define paths to train, validation, and test data folders +train_folder = os.path.join(data_base, "keras_png_slices_train") +validate_folder = os.path.join(data_base, "keras_png_slices_validate") +test_folder = os.path.join(data_base, "keras_png_slices_test") + +# Verify the data paths exist +print("Train folder:", train_folder) +print("Exists?", os.path.exists(train_folder)) + +# ----------------------------- +# Image Preprocessing Pipeline +# ----------------------------- +transform = transforms.Compose([ + transforms.Resize((64,64)), # Resize all images to 64x64 pixels for consistency + transforms.ToTensor(), # Convert PIL image to PyTorch tensor (0-1 range) + transforms.Normalize((0.5,), (0.5,)) # Normalize to [-1, 1] range (better for VAE training) +]) + +# ----------------------------- +# Dataset Creation +# ----------------------------- +# Create dataset objects for train, validation, and test sets +train_dataset = MRIDataset(train_folder, transform=transform) +val_dataset = MRIDataset(validate_folder, transform=transform) +test_dataset = MRIDataset(test_folder, transform=transform) + +# Display dataset sizes +print(f"Training samples: {len(train_dataset)}") +print(f"Validation samples: {len(val_dataset)}") +print(f"Test samples: {len(test_dataset)}") + +# ----------------------------- +# Data Visualization +# ----------------------------- +# Display a sample of training images to verify data loading +plt.figure(figsize=(8,8)) +for i in range(4): + plt.subplot(2,2,i+1) + # .squeeze() removes the channel dimension for display (1,64,64) -> (64,64) + plt.imshow(train_dataset[i].squeeze(), cmap='gray') + plt.axis('off') +plt.show() + +# ----------------------------- +# DataLoader Creation +# ----------------------------- +# Create DataLoaders for efficient batch processing during training +# DataLoaders handle batching, shuffling, and parallel data loading +train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) # Shuffle for training +val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False) # No shuffle for validation +test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False) diff --git a/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/scripts/train_vae.py b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/scripts/train_vae.py new file mode 100644 index 000000000..0753e1630 --- /dev/null +++ b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/scripts/train_vae.py @@ -0,0 +1,143 @@ +""" +VAE Training Script with Validation and Reconstruction Visualization + +PURPOSE: +- Trains a Variational Autoencoder on MRI brain scans +- Tracks training, validation, and test loss +- Visualizes reconstructed images for qualitative evaluation +- Saves model weights after training +""" + +import os +import torch +import torch.nn as nn +import torch.optim as optim +import matplotlib.pyplot as plt +from load_dataset import train_loader, val_loader, test_loader # import all loaders +from vae_model import VAE + +# ----------------------------- +# Device Configuration +# ----------------------------- +device = torch.device("mps" if 
torch.backends.mps.is_available() else "cpu") +print("Using device:", device) + +# ----------------------------- +# Model Initialization +# ----------------------------- +latent_dim = 16 +model = VAE(latent_dim=latent_dim).to(device) + +# ----------------------------- +# VAE Loss Function +# ----------------------------- +def vae_loss(recon_x, x, mu, logvar): + recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum') + kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + return recon_loss + kl_div + +# ----------------------------- +# Optimizer Setup +# ----------------------------- +optimizer = optim.Adam(model.parameters(), lr=1e-3) + +# ----------------------------- +# Training Parameters +# ----------------------------- +epochs = 15 + +# ----------------------------- +# Initialize lists to store losses +# ----------------------------- +train_losses = [] +val_losses = [] + +# ----------------------------- +# Training Loop +# ----------------------------- +for epoch in range(1, epochs+1): + model.train() + train_loss = 0 + + for batch in train_loader: + batch = batch.to(device) + optimizer.zero_grad() + recon, mu, logvar = model(batch) + loss = vae_loss(recon, batch, mu, logvar) + loss.backward() + optimizer.step() + train_loss += loss.item() + + avg_train_loss = train_loss / len(train_loader.dataset) + train_losses.append(avg_train_loss) + + # ----------------------------- + # Validation Loop + # ----------------------------- + model.eval() + val_loss = 0 + with torch.no_grad(): + for batch in val_loader: + batch = batch.to(device) + recon, mu, logvar = model(batch) + loss = vae_loss(recon, batch, mu, logvar) + val_loss += loss.item() + avg_val_loss = val_loss / len(val_loader.dataset) + val_losses.append(avg_val_loss) + + print(f"Epoch {epoch}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}") + +# ----------------------------- +# Plot Epoch vs Loss +# ----------------------------- +plt.figure(figsize=(8,5)) +plt.plot(range(1, epochs+1), train_losses, label='Train Loss') +plt.plot(range(1, epochs+1), val_losses, label='Validation Loss') +plt.xlabel("Epoch") +plt.ylabel("Loss") +plt.title("VAE Training and Validation Loss per Epoch") +plt.legend() +plt.show() + +# ----------------------------- +# Save Model +# ----------------------------- +os.makedirs("../models", exist_ok=True) +model_path = "../models/vae_model.pth" +torch.save(model.state_dict(), model_path) +print(f"Model saved to {model_path}") + +# ----------------------------- +# Final Test Evaluation +# ----------------------------- +model.eval() +test_loss = 0 +with torch.no_grad(): + for batch in test_loader: + batch = batch.to(device) + recon, mu, logvar = model(batch) + loss = vae_loss(recon, batch, mu, logvar) + test_loss += loss.item() + +avg_test_loss = test_loss / len(test_loader.dataset) +print(f"Final Test Loss: {avg_test_loss:.4f}") + +# ----------------------------- +# Final Visualization (after training) +# ----------------------------- +with torch.no_grad(): + sample_batch = next(iter(test_loader)).to(device) + recon_batch, _, _ = model(sample_batch) + plt.figure(figsize=(8,4)) + for i in range(4): + # Original + plt.subplot(2,4,i+1) + plt.imshow(sample_batch[i].cpu().squeeze(), cmap='gray') + plt.title("Original") + plt.axis('off') + # Reconstructed + plt.subplot(2,4,i+5) + plt.imshow(recon_batch[i].cpu().squeeze(), cmap='gray') + plt.title("Reconstructed") + plt.axis('off') + plt.show() diff --git 
a/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/scripts/vae_model.py b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/scripts/vae_model.py new file mode 100644 index 000000000..6d3db916a --- /dev/null +++ b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/scripts/vae_model.py @@ -0,0 +1,172 @@ +""" +Variational Autoencoder (VAE) Model Architecture + +PURPOSE: This file defines the neural network architecture for a Variational Autoencoder (VAE) +designed to work with MRI brain scan images. The VAE consists of an encoder that compresses +images into a latent space representation, and a decoder that reconstructs images from +latent vectors. + +WHY IT'S NEEDED: +- VAE learns a probabilistic representation of the data in a lower-dimensional latent space +- The encoder maps images to mean and variance parameters of a Gaussian distribution +- The decoder generates new images by sampling from the latent space +- This enables both reconstruction and generation of new MRI-like images +- The reparameterization trick allows for differentiable sampling during training +""" + +import torch +import torch.nn as nn + +class VAE(nn.Module): + """ + Variational Autoencoder for MRI brain scan images. + + This VAE architecture uses convolutional layers for the encoder and transpose + convolutional layers for the decoder, making it well-suited for image data. + """ + + def __init__(self, latent_dim=16): + """ + Initialize the VAE model. + + Args: + latent_dim (int): Dimension of the latent space representation + """ + super(VAE, self).__init__() + + # ----------------------------- + # Encoder Network + # ----------------------------- + # The encoder compresses 64x64 grayscale images into latent representations + # Each conv layer reduces spatial dimensions by half while increasing channels + self.encoder = nn.Sequential( + # Input: 1 channel (grayscale), Output: 32 channels + # Kernel size 4, stride 2, padding 1: 64x64 -> 32x32 + nn.Conv2d(1, 32, 4, 2, 1), + nn.ReLU(), + + # Input: 32 channels, Output: 64 channels + # 32x32 -> 16x16 + nn.Conv2d(32, 64, 4, 2, 1), + nn.ReLU(), + + # Input: 64 channels, Output: 128 channels + # 16x16 -> 8x8 + nn.Conv2d(64, 128, 4, 2, 1), + nn.ReLU(), + + # Flatten the 128x8x8 feature map into a vector + nn.Flatten() + ) + + # ----------------------------- + # Latent Space Mapping + # ----------------------------- + # Map the flattened features to mean and log-variance of Gaussian distribution + # This is the key difference from regular autoencoders - we learn a distribution + self.fc_mu = nn.Linear(128*8*8, latent_dim) # Mean of latent distribution + self.fc_logvar = nn.Linear(128*8*8, latent_dim) # Log-variance of latent distribution + + # ----------------------------- + # Decoder Network + # ----------------------------- + # The decoder reconstructs images from latent vectors + self.fc_decode = nn.Linear(latent_dim, 128*8*8) # Expand latent vector back to feature map size + + self.decoder = nn.Sequential( + # Reshape the vector back to 128x8x8 feature map + nn.Unflatten(1, (128, 8, 8)), + + # Transpose convolutions to upsample and reduce channels + # 8x8 -> 16x16, 128 channels -> 64 channels + nn.ConvTranspose2d(128, 64, 4, 2, 1), + nn.ReLU(), + + # 16x16 -> 32x32, 64 channels -> 32 channels + nn.ConvTranspose2d(64, 32, 4, 2, 1), + nn.ReLU(), + + # 32x32 -> 64x64, 32 channels -> 1 channel (grayscale) + nn.ConvTranspose2d(32, 1, 4, 2, 1), + + # Tanh activation outputs values in [-1, 1] range (matching our normalization) + nn.Tanh() + ) + + def reparameterize(self, mu, logvar): 
+ """ + Reparameterization trick for VAE training. + + Instead of sampling directly from N(mu, sigma), we sample from N(0,1) and transform: + z = mu + sigma * epsilon, where epsilon ~ N(0,1) + + This makes the sampling process differentiable, allowing gradients to flow through. + + Args: + mu (torch.Tensor): Mean of the latent distribution + logvar (torch.Tensor): Log-variance of the latent distribution + + Returns: + torch.Tensor: Sampled latent vector + """ + # Convert log-variance to standard deviation + std = torch.exp(0.5 * logvar) + + # Sample random noise from standard normal distribution + eps = torch.randn_like(std) + + # Apply reparameterization: z = mu + std * epsilon + return mu + eps * std + + def forward(self, x): + """ + Forward pass through the VAE. + + Args: + x (torch.Tensor): Input images of shape (batch_size, 1, 64, 64) + + Returns: + tuple: (reconstructed_images, mu, logvar) + - reconstructed_images: Generated images + - mu: Mean of latent distribution + - logvar: Log-variance of latent distribution + """ + # ----------------------------- + # Encoding: Image -> Latent Distribution + # ----------------------------- + # Pass through encoder to get feature representation + enc = self.encoder(x) + + # Map features to latent distribution parameters + mu = self.fc_mu(enc) # Mean of latent Gaussian + logvar = self.fc_logvar(enc) # Log-variance of latent Gaussian + + # ----------------------------- + # Sampling: Latent Distribution -> Latent Vector + # ----------------------------- + # Sample from the learned distribution using reparameterization trick + z = self.reparameterize(mu, logvar) + + # ----------------------------- + # Decoding: Latent Vector -> Reconstructed Image + # ----------------------------- + # Expand latent vector and pass through decoder + out = self.decoder(self.fc_decode(z)) + + return out, mu, logvar + +# ----------------------------- +# Model Testing and Verification +# ----------------------------- +if __name__ == "__main__": + # Test the model with random input to verify architecture + model = VAE(latent_dim=16) + x = torch.randn(2, 1, 64, 64) # batch of 2 grayscale images + + # Forward pass + recon, mu, logvar = model(x) + + # Print shapes to verify everything works correctly + print("Input shape:", x.shape) + print("Reconstructed shape:", recon.shape) + print("Latent mu shape:", mu.shape) diff --git a/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/scripts/vae_visualise.py b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/scripts/vae_visualise.py new file mode 100644 index 000000000..c5cdcbf17 --- /dev/null +++ b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/scripts/vae_visualise.py @@ -0,0 +1,100 @@ +""" +VAE Visualization and Generation Script + +PURPOSE: This script loads a trained VAE model and generates new MRI-like images by sampling +from the learned latent space. It demonstrates the generative capabilities of the VAE by +creating synthetic brain scan images that follow the patterns learned during training. 
+ +WHY IT'S NEEDED: +- Demonstrates the generative power of the trained VAE model +- Shows how the model learned to map random noise to realistic MRI images +- Provides visual verification that training was successful +- Generates new synthetic data that could be used for data augmentation +- Illustrates the continuous nature of the learned latent space +""" + +import os +import torch +import matplotlib.pyplot as plt +from vae_model import VAE # Import our VAE model architecture + +# ----------------------------- +# Device Configuration +# ----------------------------- +# Use the same device configuration as training (GPU if available) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") + +# ----------------------------- +# Model Loading +# ----------------------------- +# Initialize the VAE model with the same architecture as training +latent_dim = 16 # Must match the latent dimension used during training +model = VAE(latent_dim=latent_dim).to(device) + +# Load the trained model weights from the saved checkpoint +model_path = "../models/vae_model.pth" +if os.path.exists(model_path): + model.load_state_dict(torch.load(model_path, map_location=device)) + print(f"Successfully loaded model from {model_path}") +else: + print(f"Error: Model file not found at {model_path}") + print("Please run train_vae.py first to train and save the model.") + exit(1) + +# Set model to evaluation mode (disables dropout, batch norm uses running stats) +model.eval() + +# ----------------------------- +# Image Generation +# ----------------------------- +# Number of images to generate +num_samples = 16 + +print(f"Generating {num_samples} new MRI images...") + +# Generate new images by sampling from the latent space +with torch.no_grad(): # Disable gradient computation for inference + # Sample random latent vectors from standard normal distribution N(0,1) + # This is the key insight: we can generate new images by sampling from the + # learned latent space distribution + z = torch.randn(num_samples, latent_dim).to(device) + + # Pass the random latent vectors through the decoder to generate images + # We bypass the encoder and directly use the decoder part of the VAE + generated = model.decoder(model.fc_decode(z)) + +print("Image generation completed!") + +# ----------------------------- +# Visualization +# ----------------------------- +# Create a grid layout to display all generated images +plt.figure(figsize=(8, 8)) +plt.suptitle("Generated MRI Samples from VAE", fontsize=16, fontweight='bold') + +# Display each generated image in a 4x4 grid +for i in range(num_samples): + plt.subplot(4, 4, i + 1) + + # Convert tensor to numpy and remove channel dimension for display + # .squeeze() removes the channel dimension: (1, 64, 64) -> (64, 64) + # .cpu() moves tensor from GPU to CPU for matplotlib + img = generated[i].squeeze().cpu() + + # Display the image in grayscale + plt.imshow(img, cmap='gray') + plt.axis('off') # Remove axes for cleaner display + +# Adjust layout to prevent overlapping +plt.tight_layout() + +# Show the plot +plt.show() + +print("Visualization completed!") +print("\nWhat you're seeing:") +print("- Each image is generated from a random 16-dimensional vector") +print("- The VAE learned to map these random vectors to realistic MRI-like images") +print("- The diversity shows the model learned a rich representation of brain anatomy") +print("- These are completely synthetic images, not reconstructions of training data") diff --git 
a/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/train.py b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/train.py new file mode 100644 index 000000000..ca503a3c9 --- /dev/null +++ b/OASIS-Improved-UNet-s4802308/OASIS-ImprovedUNet/train.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import Dataset, DataLoader +import matplotlib.pyplot as plt +from modules import ImprovedUNet +from dataset import load_oasis_data + +# CONFIG +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +BATCH_SIZE = 8 +EPOCHS = 35 +LR = 1e-3 + +# Data loading: update paths as needed +train_images, train_labels = load_oasis_data( + '/home/groups/comp3710/OASIS/keras_png_slices_train', + '/home/groups/comp3710/OASIS/keras_png_slices_seg_train' +) +val_images, val_labels = load_oasis_data( + '/home/groups/comp3710/OASIS/keras_png_slices_validate', + '/home/groups/comp3710/OASIS/keras_png_slices_seg_validate' +) + +class OasisDataset(Dataset): + def __init__(self, images, labels): + self.images = torch.tensor(images, dtype=torch.float32).unsqueeze(1)/255.0 + self.labels = torch.tensor(labels, dtype=torch.float32).unsqueeze(1)/255.0 + + def __len__(self): + return self.images.shape[0] + + def __getitem__(self, idx): + return self.images[idx], self.labels[idx] + +train_ds = OasisDataset(train_images, train_labels) +val_ds = OasisDataset(val_images, val_labels) +train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True) +val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False) + +model = ImprovedUNet(n_channels=1, n_classes=1).to(device) +loss_fn = nn.BCEWithLogitsLoss() +optimizer = optim.Adam(model.parameters(), lr=LR) + +train_losses, val_losses = [], [] + +for epoch in range(EPOCHS): + model.train() + running_loss = 0.0 + for imgs, masks in train_dl: + imgs, masks = imgs.to(device), masks.to(device) + optimizer.zero_grad() + outputs = model(imgs) + loss = loss_fn(outputs, masks) + loss.backward() + optimizer.step() + running_loss += loss.item() * imgs.size(0) + avg_train_loss = running_loss / len(train_dl.dataset) + train_losses.append(avg_train_loss) + + model.eval() + running_vloss = 0.0 + with torch.no_grad(): + for vimgs, vmasks in val_dl: + vimgs, vmasks = vimgs.to(device), vmasks.to(device) + voutputs = model(vimgs) + vloss = loss_fn(voutputs, vmasks) + running_vloss += vloss.item() * vimgs.size(0) + avg_val_loss = running_vloss / len(val_dl.dataset) + val_losses.append(avg_val_loss) + print(f'Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}') + + # Save checkpoint + torch.save(model.state_dict(), f"unet_epoch{epoch+1}.pth") + +# Plot losses +plt.plot(train_losses, label='Train') +plt.plot(val_losses, label='Val') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.legend() +plt.savefig('loss_plot.png') +plt.show() diff --git a/OASIS-Improved-UNet-s4802308/README.md b/OASIS-Improved-UNet-s4802308/README.md new file mode 100644 index 000000000..0090e457f --- /dev/null +++ b/OASIS-Improved-UNet-s4802308/README.md @@ -0,0 +1,185 @@ +# Improved UNet for OASIS Brain MRI Segmentation + +**Student Name:** Oliver McCarthy +**Student Number:** 48023083 + +## Introduction + +This assignment presents a solution to the problem of biomedical image segmentation, specifically segmenting the 2D OASIS brain dataset using an Improved UNet architecture in PyTorch. 
The task demonstrates practical skills in deep learning for medical imaging, software engineering in version control, and reproducible research reporting. The goal is to achieve accurate segmentation on preprocessed brain MRI slices while documenting the workflow and outcomes. + +--- + +## Problem Description + +The task is to segment anatomical regions from 2D brain MRI slices. Images are sourced from the OASIS dataset, with corresponding ground-truth segmentation masks provided. Accurate segmentation of medical images is critical for quantitative analysis in neuroimaging research and clinical diagnosis. + +--- + +## Project Structure + +recognition/ +└── OASIS-ImprovedUNet/ + ├── dataset.py # Loads and preprocesses OASIS MRI & label data + ├── modules.py # Improved UNet model architecture (PyTorch) + ├── train.py # Model training and validation loop + ├── predict.py # Example inference script + result visualisations + └── README.md # Documentation and usage instructions + +--- + +## Algorithm Summary + +The Improved UNet model is a convolutional neural network with an encoder–decoder structure optimised for semantic segmentation tasks. It processes grey-scale brain MRI slices and produces corresponding binary segmentation masks. The model is trained using binary cross-entropy loss with logits and tracks training and validation loss across epochs. + +--- + +## Workflow and Usage Instructions + +### Code Structure + +- **`modules.py`** – Implements the Improved UNet architecture components. +- **`dataset.py`** – Loads and preprocesses OASIS image and mask data from PNG files. +- **`train.py`** – Handles model training, validation, checkpoint saving, and loss plotting. +- **`predict.py`** – Loads trained checkpoints and generates segmentation predictions with visualisations. + +### Pipeline + +1. **Data Preparation** + - OASIS 2D slice images and segmentation masks (stored as PNGs) are loaded and normalised. + - Training, validation, and testing splits follow the assignment convention. + +2. **Training** + - The model trains for a fixed number of epochs. + - Training and validation losses are tracked to monitor overfitting. + +3. **Prediction & Visualisation** + - The best checkpoint is used to generate segmentation masks on test images. + - Predicted outputs are visualised alongside the ground truth. + +### Reproducible Setup + +1. Install and activate a Python 3.11+ virtual environment inside the project directory. + +python3 -m venv myenv +source myenv/bin/activate + +2. Install dependencies +pip install torch matplotlib imageio numpy + +3. Run Model Training +python3 train.py + +4. Run prediction and generate output images +python3 predict.py + +5. Plots and prediction images will be saved in the directory; download as needed. + +--- + +## Data & Splitting Rationale +- Dataset: The OASIS 2D slice data is provided and preprocessed for direct usage. +- Splits: Training, validation and test sets follow the supplied partitioning to ensure consistency and fair evaluation. + +--- + +## Dependencies +- Python 3.11+ +- PyTorch +- matplotlib +- numpy +- imageio + +--- + +## Limitations & Improvements + +- The current implementation applies the Improved UNet to the supplied OASIS dataset only, which meets the easy difficulty criteria for this assignment. +- The dataset is normalised and reshaped for single-channel input; no additional data augmentation is used. 
+- The project could be extended to more challenging tasks such as segmentation on the HipMRI dataset, use of 3D models, or implementation of advanced architectures. +- Additional evaluation metrics (e.g., Dice coefficient) could be implemented for quantitative assessment. + +--- + +## How to Run + +- Activate your Python environment. +- Install required packages. +- Run scripts as described above. +- Visualisations and plots are produced automatically; check your working directory for output files. + +--- + +## Results + +### Training and Validation Loss + +![Loss Plot](loss_plot.png) + +The training and validation loss values, tracked over 15 epochs, are visualised in the loss plot above. The training loss remains consistently low from the start, indicating that the model is rapidly able to fit the training data. In contrast, the validation loss begins substantially higher, reflecting the challenge the model initially faces when generalising to new data. + +Notably, the validation loss decreases sharply after the first few epochs, with the largest improvement occurring between epochs 4 and 5. By epoch 9, the validation loss stabilises at a considerably lower value—much closer to the training loss—demonstrating that the model's performance on unseen data has improved and is no longer dramatically different from its performance during training. + +The convergence and parallel decline of both losses toward the end suggest successful learning, effective model generalisation, and little evidence of overfitting. This pattern, with high initial validation loss followed by rapid convergence, is typical for deep learning models as they transition from learning broad patterns to refining their segmentation boundary details over successive epochs. + +Monitoring both curves ensures the model is not simply memorising the training set, but rather learning generalisable features essential for robust MRI segmentation. + +#### Dice Similarity Coefficient + +The Dice coefficient for segmentation predictions on the test set was **0.6323**, which does not meet the assignment threshold (≥ 0.9). This score reflects that, while the model segments the main brain regions effectively, it struggles with accurately capturing fine anatomical details and boundaries, as seen in both the visual results and the quantitative metric. + + +### Sample Prediction 1 +![Prediction 0](prediction_0.png) + +In Prediction 1, the input MRI slice (left) is clearly segmented with distinct anatomical structures in the ground truth mask (centre). The predicted mask (right) successfully outlines major brain regions and excludes most of the non-brain background. The segmentation captures general tissue contours and shows good agreement with the shape and location of structures in the ground truth. + +However, there are noticeable artifacts, such as ring-like patterns near the outer edge of the mask and some discontinuities inside the brain area. These artifacts likely result from thresholding imperfections, limited post-processing, or model bias toward certain regions seen during training. Fine internal features appear less precisely defined in the prediction—some small structures are missed or blurred compared to the ground truth. + +Overall, the model provides a reasonable binary segmentation for core brain regions but could be improved to reduce edge artifacts and sharpen internal anatomical boundaries. Incorporating additional metrics or data augmentation may help in addressing these issues in future iterations. 
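One way to reduce the peripheral ring artifacts and internal discontinuities noted above would be a lightweight connected-component post-processing step. The sketch below is illustrative only and is not part of the submitted code; it assumes SciPy is installed and that `pred_mask` is a boolean NumPy array obtained by thresholding the sigmoid output in `predict.py`.

```python
# Hypothetical post-processing helper (assumes SciPy; not part of this repository).
import numpy as np
from scipy import ndimage

def clean_mask(pred_mask: np.ndarray) -> np.ndarray:
    """Keep the largest connected component and fill holes in a boolean mask."""
    labelled, num_regions = ndimage.label(pred_mask)
    if num_regions == 0:
        return pred_mask  # nothing segmented, nothing to clean
    # Size of each labelled region (mask is boolean, so the sum equals the pixel count)
    sizes = ndimage.sum(pred_mask, labelled, range(1, num_regions + 1))
    largest = labelled == (int(np.argmax(sizes)) + 1)
    # Fill internal holes left by thresholding discontinuities
    return ndimage.binary_fill_holes(largest)
```

Because well-segmented slices already consist of a single large component, this kind of cleanup is a low-risk addition before Dice evaluation.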
+ +### Sample Prediction 2 +![Prediction 1](prediction_1.png) + +This figure again illustrates the input MRI, the ground truth, and the model’s predicted segmentation. The prediction displays strong overall agreement in identifying the brain region, with the major lobes and boundaries reasonably well matched to the ground truth. + +However, closer inspection reveals several limitations: the model misses some of the finer internal details, leading to thicker and less contoured segmentation of internal regions compared to the true mask. There are still faint ring artifacts, and some central features appear overly connected or blurred. Despite these artifacts, the core region of the brain is successfully segmented, and the model avoids large false positives outside the brain. + +Interpreting these results, the model clearly generalises the broad brain shape but could benefit from additional regularisation, improved post-processing, or refined loss functions to increase the precision of boundary and internal structure segmentation. + +### Sample Prediction 3 +![Prediction 2](prediction_2.png) + +This prediction shows strong alignment between the predicted mask and the main brain tissue in the ground truth, with the overall shape and location of the brain accurately delineated. The segmentation correctly excludes most of the background and captures the gross anatomical structure. + +Nonetheless, similar to prior predictions, the model does not capture every fine internal contour, with some subtle boundaries and intricate regions blurred or missed entirely in the prediction. Artefactual ring patterns remain present at the periphery, and some thin structures present in the ground truth are not fully segmented, possibly due to the network’s preference for large, contiguous foreground regions and limitations in spatial resolution. + +In summary, the result is a clean, faithful segmentation of the brain region, appropriate for a baseline model, but still limited in detail, highlighting opportunities for more advanced modelling or post-processing if higher precision is needed. + +### Prediction Summary + +Each example displays the input brain MRI slice, its ground truth segmentation mask, and the model's predicted mask. The predicted segmentation masks are qualitatively similar to the ground truth, capturing major anatomical structures and providing good separation between tissue and background. + +- Strengths: + - The Improved UNet achieves crisp, binary mask boundaries and extracts the main brain regions with high similarity to the provided ground truth. Most large-scale features and shapes are very well captured, reflecting effective training and model architecture. + +- Limitations: + - Some fine details and small structures are missed, and there are occasional artifacts (e.g., ring-like shapes) at the periphery of the predicted masks. These may result from limitations in model depth, preprocessing, or the thresholding method. Additionally, some regions of the brain boundaries look slightly thicker or thinner in predictions than in the ground truth, likely due to imperfect probability thresholding. + +- Overall: + - The model provides a robust segmentation baseline, with room for improvement through enhanced data augmentation, post-processing, or advanced architectures. + +#### Limitations and Troubleshooting + +Despite retraining the model for 35 epochs (instead of 15), monitoring the validation loss, and experimenting with hyperparameters, the Dice similarity coefficient remained below the required value.
Potential causes include: +- Insufficient data augmentation or model regularisation +- Suboptimal thresholding for mask binarisation +- The model's limited ability to generalise to fine structures and small regions within the brain masks + +Further improvements could include more extensive data preprocessing, different model architectures, and advanced training techniques. Given time constraints, these enhancements were not implemented, but would be the logical next steps for improving segmentation quality.
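As one concrete example of such a training-side improvement, the BCE objective in `train.py` could be combined with a soft Dice loss so that the overlap metric being reported is optimised directly. The snippet below is a generic, hedged sketch rather than part of this submission; it assumes the same single-channel logits and binary masks used elsewhere in the project.

```python
import torch
import torch.nn as nn

class SoftDiceLoss(nn.Module):
    """Differentiable Dice loss for binary segmentation (illustrative sketch only)."""
    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(logits)          # (N, 1, H, W) probabilities
        dims = (1, 2, 3)
        intersection = (probs * targets).sum(dim=dims)
        union = probs.sum(dim=dims) + targets.sum(dim=dims)
        dice = (2 * intersection + self.epsilon) / (union + self.epsilon)
        return 1 - dice.mean()

# Possible combined objective (an assumption, not the submitted loss):
# bce = nn.BCEWithLogitsLoss()
# dice = SoftDiceLoss()
# loss = bce(outputs, masks) + dice(outputs, masks)
```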