diff --git a/recognition/oasis_unet_timothy_nguyen/README.md b/recognition/oasis_unet_timothy_nguyen/README.md new file mode 100644 index 000000000..f01a79cce --- /dev/null +++ b/recognition/oasis_unet_timothy_nguyen/README.md @@ -0,0 +1,362 @@ +# 2D U-Net to Segment the OASIS Brain MRI Dataset + +## Author +Timothy Nguyen – s4699147 + +## Contents +- [The Task](#the-task) +- [Dependencies](#dependencies) +- [Usage](#usage) +- [U-Net Algorithm](#u-net-algorithm) +- [Training Details](#training-details) +- [Results](#results) +- [Conclusion](#conclusion) +- [References](#references) + +## The Task +This project implements a 2D U-Net model to perform **multi-class brain tissue segmentation** on axial MRI slices from the **OASIS** dataset. The overall aim is to achieve a [high Dice similarity coefficient](https://en.wikipedia.org/wiki/Dice-S%C3%B8rensen_coefficient) (≈0.90) across the OASIS labels using a lightweight, reproducible PyTorch pipeline. + +The dataset used here is the “pre-sliced” 2D version of OASIS that is available on UQ infrastructure or the [OASIS brain study](https://sites.wustl.edu/oasisbrains/) (and can also be mirrored to Google Drive for Colab). Depending on how the data was downloaded, images may be stored in PNG slices or in NIfTI volumes. This project uses only **PNG** slices to simplify preprocessing and reduce dependencies. + +The target segmentation classes for this task are: +- Background +- Gray matter +- White matter +- Ventricular / CSF + +An example image is shown below: + +![Example OASIS MRI image](images/example_oasis_image.png) +![Example OASIS MRI Segment image](images/example_oasis_seg_image.png) + +More information about the original OASIS study can be found on the official [site](https://sites.wustl.edu/oasisbrains/). + +## Dependencies +This project was developed and tested with the following software stack. To minimise “works on my machine” issues, try to stay close to these versions. + +- **Python**: 3.9.23 +- **PyTorch**: 2.5.1 +- **Torchvision**: 0.20.1 +- **NumPy**: 2.0.1 +- **Matplotlib**: 3.9.2 +- **Pillow**: 11.3.0 +- **Nibabel**: 5.3.2 +- **tqdm**: 4.67.1 + +If you are using **conda**, a typical install would be: + +```bash +# Create and activate a new environment +conda create -n comp3710-oasis python=3.9.23 +conda activate comp3710-oasis +# Install GPU-enabled PyTorch + Torchvision (CUDA 12.1) +conda install pytorch==2.5.1 torchvision==0.20.1 pytorch-cuda=12.1 -c pytorch -c nvidia +# Install additional dependencies +conda install numpy==2.0.1 matplotlib==3.9.2 pillow==11.3.0 -c conda-forge +# Optional dependencies (legacy or utilities) +pip install nibabel==5.3.2 tqdm==4.67.1 +``` +Note: Different versions of PyTorch and CUDA can be installed as seen on the [PyTorch website](https://pytorch.org/get-started/locally/). + +## Usage + +### Installation +1. **Clone the repository** +``` +git clone https://github.com/tmthyngyn/PatternAnalysis-2025.git +``` +2. **OPTIONAL: Use Conda to create and use a virtual environment** +``` +conda activate comp3710-oasis +``` +3. **[Install dependencies](#dependencies)** + +4. **Download Brain MRI data** + +If available, access and retrieve the data from Rangpur Path: /home/groups/comp3710/OASIS. + +Otherwise, the data can be downloaded from the OASIS website [here](https://sites.wustl.edu/oasisbrains/). It will need to be separated into training, testing and validation splits. 
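+
+If you start from a flat folder of slice PNGs, the split can be scripted. The sketch below is only illustrative (the `split_oasis_pngs` helper, the source paths, and the 80/10/10 ratios are assumptions, not part of this repository); it groups slices by case ID (matching the `case_*`/`seg_*` naming that `dataset.py` parses) so that no subject appears in more than one split, and copies the files into the canonical layout described below.
+
+```python
+import glob, os, random, shutil
+
+def split_oasis_pngs(src_images, src_labels, dst_root="./OASIS",
+                     ratios=(0.8, 0.1, 0.1), seed=0):
+    """Copy case_*/seg_* PNG slices into OASIS/{train,val,test}/{images,labels}."""
+    def case_id(path):
+        # "case_367_slice_20.nii.png" -> "367", "seg_367_slice_20.png" -> "367"
+        return os.path.basename(path).split("_slice_")[0].split("_")[-1]
+
+    images = sorted(glob.glob(os.path.join(src_images, "*.png")))
+    labels = sorted(glob.glob(os.path.join(src_labels, "*.png")))
+    cases = sorted({case_id(p) for p in images})
+    random.Random(seed).shuffle(cases)
+
+    n_train = int(ratios[0] * len(cases))
+    n_val = int(ratios[1] * len(cases))
+    split_of = {c: ("train" if i < n_train else "val" if i < n_train + n_val else "test")
+                for i, c in enumerate(cases)}
+
+    for split in ("train", "val", "test"):
+        os.makedirs(os.path.join(dst_root, split, "images"), exist_ok=True)
+        os.makedirs(os.path.join(dst_root, split, "labels"), exist_ok=True)
+    for p in images:
+        shutil.copy2(p, os.path.join(dst_root, split_of[case_id(p)], "images", os.path.basename(p)))
+    for p in labels:
+        shutil.copy2(p, os.path.join(dst_root, split_of[case_id(p)], "labels", os.path.basename(p)))
+```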
+
+Notes: The dataset must be organised in the [canonical](#directory-structure) OASIS directory format for the scripts to function correctly. Please move, rename, or adjust files and scripts accordingly, as described in [Using the scripts](#using-the-scripts).
+
+### Directory Structure
+
+#### Scripts
+
+```
+├───dataset.py            # Dataset loading and preprocessing (PNG backend)
+├───modules.py            # U-Net model architecture
+├───train.py              # Training and validation loop
+├───predict.py            # Model inference and visualisation (includes --scan mode)
+├───README.md             # Project report and documentation
+```
+
+#### Images
+It is advised to keep images and folders in the format suggested below; otherwise, the scripts will need to be adjusted to match your specific structure. See [Using the scripts](#using-the-scripts) for details.
+
+As the folder names suggest, test, train and validate refer to the testing, training, and validation splits of the data, respectively.
+```
+OASIS/
+ ├── train/
+ │   ├── images/   # input PNGs
+ │   └── labels/   # integer masks (0:BG, 1:CSF, 2:GM, 3:WM)
+ ├── val/
+ │   ├── images/
+ │   └── labels/
+ └── test/
+     ├── images/
+     └── labels/
+```
+
+Each of these should in turn contain two folders named images and labels. The images folders hold the grayscale MRI slice images in .png format, while the labels folders contain the segmentation masks in .png format with integer values representing the four classes:
+
+| Label  | Class        | Description                         |
+| :----: | ------------ | ----------------------------------- |
+| **C0** | Background   | Non-brain regions outside the skull |
+| **C1** | CSF          | Cerebrospinal fluid                 |
+| **C2** | Gray Matter  | Outer cortical layer                |
+| **C3** | White Matter | Inner myelinated tissue             |
+
+You can learn more about the dataset at the [OASIS project page](https://sites.wustl.edu/oasisbrains/).
+
+### Using the Scripts
+
+1. **Before running the scripts**
+
+Before running any of the scripts, remember to [clone](#usage) the repository and change into the project directory.
+
+```
+cd PatternAnalysis-2025/recognition/oasis_unet_timothy_nguyen
+```
+
+Make sure your dataset is correctly placed in the expected directory structure and that you are in the correct working directory.
+Within this folder, you should create the canonical OASIS directory that contains three subdirectories: train, val, and test. If you want everything self-contained, place the dataset inside the folder as follows:
+
+```
+PatternAnalysis-2025/
+└── recognition/
+    └── oasis_unet_timothy_nguyen/
+        ├── OASIS/
+        │   ├── train/
+        │   │   ├── images/
+        │   │   └── labels/
+        │   ├── val/
+        │   │   ├── images/
+        │   │   └── labels/
+        │   └── test/
+        │       ├── images/
+        │       └── labels/
+        ├── dataset.py
+        ├── modules.py
+        ├── train.py
+        ├── predict.py
+        └── README.md
+```
+
+Once this structure is in place, the scripts will automatically locate the data based on the --root argument you provide when executing them. If your dataset is located elsewhere, you can use an absolute path by passing it as --root "C:/path/to/OASIS" instead. No changes inside the scripts are required, as the path handling is managed entirely by command-line arguments.
+
+The script should now be ready to run.
+
+2.
**Training** + +To train the model, open a terminal or command prompt, navigate to the project directory (recognition/oasis_unet_timothy_nguyen), and execute the following command: + +``` +python train.py --root ./OASIS --epochs 12 --batch-size 4 --num-classes 4 +``` + +This command will begin the training process using the training images and labels located under OASIS/train/ and will validate the model’s performance using the data under OASIS/val/. During training, the script will output progress to the console for each epoch, showing both loss and Dice similarity scores for the training and validation sets. When training completes, the model checkpoint and performance plots are automatically saved inside a new directory named trained_models/oasis_unet/. This directory contains a file called best_model.pth which stores the trained model weights, along with two figures, loss.png and dice.png, that illustrate the training and validation curves. For instance, after training for 12 epochs, you might see console output similar to the following: + +``` +[Epoch 012] Train Loss: 0.0885 | Val Loss: 0.2087 | Train Dice: 0.9403 | Val Dice: 0.8897 +Training complete. Best Val Dice: 0.9221. Artifacts saved to: trained_models/oasis_unet +``` + +3. **Predicting** + +Once the model has been successfully trained, you can generate predictions using the predict.py script. The model can be used to visualise the segmentation of a single example from the validation or test dataset. To do this, execute: + +``` +python predict.py --root ./OASIS --ckpt trained_models/oasis_unet/best_model.pth --split val --index 0 +``` + +This command loads the model from the saved checkpoint file best_model.pth and runs inference on a single image specified by the --index argument (in this case, the first image in the validation set). The resulting visualisation, which shows the input MRI slice, its ground-truth segmentation mask, and the model’s prediction side-by-side, will be saved to outputs/prediction_example.png. In the terminal, the script will print the Dice score for each of the four segmentation classes (C0: Background, C1: CSF, C2: GM, C3: WM) as well as the mean Dice score for the chosen sample. + +If you wish to evaluate the model across an entire dataset split and identify the best, worst, and median performing examples, you can enable scan mode using the --scan flag. For instance, to scan all validation images, run: + +``` +python predict.py --root ./OASIS --ckpt trained_models/oasis_unet/best_model.pth --split val --scan +``` + +The script will iterate through all images in the specified split, compute Dice scores for each, and then generate three summary figures in outputs/gallery/: best.png, worst.png, and decent.png. These figures illustrate the input, ground truth, and predicted segmentations for the highest-scoring image, the lowest-scoring image, and one with a median Dice score respectively. Each figure includes the dataset index and the corresponding per-class and mean Dice scores in its title. To evaluate the model’s generalisation performance on unseen data, you can repeat the same command with --split test to analyse the test dataset. + +4. **Review** + +After training and prediction, your folder structure will include additional directories automatically created by the scripts. 
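+
+If you prefer to inspect results from Python rather than through `predict.py`, the saved checkpoint can also be loaded directly. The snippet below is a rough sketch (paths are illustrative); it relies only on the keys that `train.py` writes into `best_model.pth` and on the public classes in `modules.py` and `dataset.py`.
+
+```python
+import torch
+from modules import UNet
+from dataset import OASIS2DSegmentation
+
+ckpt = torch.load("trained_models/oasis_unet/best_model.pth", map_location="cpu")
+model = UNet(in_channels=1, out_channels=ckpt["num_classes"])
+model.load_state_dict(ckpt["model_state"])
+model.eval()
+
+ds = OASIS2DSegmentation(root="./OASIS", split="val")
+img, mask = ds[0]                                    # img: (1, H, W), mask: (H, W)
+with torch.no_grad():
+    pred = model(img.unsqueeze(0)).argmax(dim=1)[0]  # (H, W) predicted class indices
+print("Best epoch:", ckpt["epoch"], "| stored val Dice:", ckpt["val_dice"])
+```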
The final project directory will look like this: +``` +recognition/oasis_unet_timothy_nguyen/ +├── OASIS/ +│ ├── train/ (images/, labels/) +│ ├── val/ (images/, labels/) +│ └── test/ (images/, labels/) +├── dataset.py +├── modules.py +├── train.py +├── predict.py +├── README.md +│ +├── trained_models/ +│ └── oasis_unet/ +│ ├── best_model.pth +│ ├── loss.png +│ └── dice.png +│ +├── outputs/ +│ ├── prediction_example.png +│ └── gallery/ +│ ├── best.png +│ ├── worst.png +│ └── decent.png +│ +└── __pycache__/ +``` + +By following these steps, the entire pipeline—training, evaluation, and prediction—can be run locally without modifying any file paths inside the code. The trained model and visual outputs will be saved automatically, allowing you to easily inspect results. + +## U-Net Algorithm +[U-Net](https://en.wikipedia.org/wiki/U-Net) is a fully convolutional encoder–decoder network designed for dense, pixel-wise segmentation. It was introduced for biomedical imaging by Ronneberger, Fischer, and Brox (2015) and has since become a standard baseline across medical and general computer vision tasks because it couples strong context capture (downsampling path) with precise localisation (upsampling path with skip connections) [Ronneberger et al., 2015](https://arxiv.org/pdf/1505.04597). + +In this project, a 2D U-Net is trained on single-channel PNG slices of brain MRI to predict a four-class semantic mask (background, CSF, gray matter, white matter). The model produces logits of shape (B, C, H, W) with C=4, and the final prediction is obtained by argmax over the channel dimension. + +### Architectural Overview + +Conceptually, U-Net follows a symmetric “U” shape consisting of a contracting path (encoder) and an expanding path (decoder), bridged by a bottleneck resembling a typical [autoencoder](https://en.wikipedia.org/wiki/Autoencoder). The encoder repeatedly applies two small convolutions to enrich features and then downsamples to expand the receptive field. The decoder upsamples to recover resolution, and at each scale it concatenates the upsampled features with the matching encoder features via skip connections; this restores fine detail that would otherwise be lost. + +![Original U-Net architecture](images/u-net-architecture.png) + +Concretely in our implementation (see modules.py): +- Each encoder stage applies two Conv2d → activation layers (optionally with normalisation) followed by a 2×2 downsampling (e.g., max-pool or stride-2 conv). Channels typically double as resolution halves (e.g., 32→64→128…). +- Each decoder stage upsamples by a factor of two (transpose convolution or interpolation+conv), concatenates the corresponding encoder features (the skip), and applies two Conv2d → activation layers. Channels typically halve as resolution doubles. +- A final 1×1 convolution maps features to C=4 class logits so that the output has the same spatial size as the input. + +## Training Details + +### Why U-Net fits brain-MRI tissue segmentation + +Brain-MRI tissue classes exhibit subtle intensity differences and smooth boundaries. The encoder aggregates global context that disambiguates tissue appearance, while skip connections supply high-frequency detail for accurate boundaries (e.g., CSF ventricles or GM/WM interfaces). This architecture is data-efficient (important for medical datasets) and works well in 2D slice-wise settings, which keeps compute and memory demands modest. 
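+
+A quick way to confirm the input/output contract described above is a one-off shape check (a sketch; the 256×256 size is only an example, and the padding logic in `modules.py` aligns other sizes as well):
+
+```python
+import torch
+from modules import UNet
+
+model = UNet(in_channels=1, out_channels=4)
+x = torch.randn(2, 1, 256, 256)       # dummy batch of grayscale slices
+logits = model(x)
+print(logits.shape)                   # torch.Size([2, 4, 256, 256]): logits are (B, C, H, W)
+pred = logits.argmax(dim=1)           # (2, 256, 256) integer class map, as used at inference
+```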
+ +### Training objective and inference + +Training minimises multi-class cross-entropy over logits (B, C, H, W), optionally with inverse-frequency class weights derived from the training masks to counter the dominance of the background class. + +The Dice Similarity Coefficient (DSC) was used as the primary evaluation metric to measure the overlap between predicted and ground-truth segmentations. Alternative loss formulations such as Dice loss were later popularised for volumetric segmentation in [V-Net](https://arxiv.org/pdf/1606.04797). + +The metric is defined for each class \( c \) as: + +$$ +\mathrm{Dice}_c = \frac{2\,|P_c \cap T_c|}{|P_c| + |T_c|} += \frac{2 \sum_i \mathbb{1}[\hat{y}_i=c] \mathbb{1}[y_i=c]} +{\sum_i \mathbb{1}[\hat{y}_i=c] + \sum_i \mathbb{1}[y_i=c]} +$$ + +At inference, the model’s logits are converted to a discrete mask by argmax over channels. No CRFs or post-processing are applied in this baseline to keep the pipeline simple and reproducible. + +### Architecture realised in this project + +The project uses a 2D U-Net (slice-wise) with single-channel input (1, H, W) and four output classes. Slice intensities are z-scored per image in dataset.py. Labels are remapped to a compact range {0,1,2,3} if needed so that the loss and metrics align with --num-classes 4. Padding is used so that encoder/decoder feature maps align without cropping; the final prediction has identical height and width to the input slice. + +### Data preprocessing and splits + +All training operates on PNG slices in the canonical OASIS layout (train/, val/, test/, each with images/ and labels/). Inputs are normalised by per-slice z-score. The training split drives learning, the validation split monitors generalisation and selects the best checkpoint, and the test split is reserved for the final, unbiased report of performance. This separation avoids information leakage and mirrors standard medical-imaging practice. + +## Results + +### Training Metrics +During training, the model optimised the cross-entropy loss between the predicted and true segmentation masks, with the objective of maximising the Dice Similarity Coefficient (DSC) across all tissue classes (background, CSF, gray matter, and white matter). The training and validation metrics for each epoch are plotted below. + +![Cross Entropy loss versus epoch for both training and validation sets.](images/loss.png) + +![Mean Dice coefficient versus epoch for both training and validation sets.](images/dice.png) + +As seen in the figures above, the cross-entropy loss steadily decreased for both training and validation, while the mean Dice scores improved consistently over the 12 epochs. The validation Dice curve shows slight oscillations due to the relatively small dataset size and the sensitivity of the Dice metric to small segmentation variations; however, the overall trend is strongly increasing, suggesting robust learning and minimal overfitting. + +At the end of training, the model achieved an average validation Dice score of 0.9747, with the following per-class performance: + +| Class | Label | Dice Score | +| :----: | :----------- | :--------: | +| **C0** | Background | **0.9987** | +| **C1** | CSF | **0.9536** | +| **C2** | Gray Matter | **0.9663** | +| **C3** | White Matter | **0.9802** | + +This indicates excellent segmentation accuracy across all classes, with particularly strong performance on the gray and white matter regions, which typically represent the most complex structures in brain tissue segmentation. 
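+
+For reference, the per-class scores above follow the Dice definition given earlier; a minimal NumPy version (equivalent in spirit to the helper used in `predict.py`) looks like this:
+
+```python
+import numpy as np
+
+def dice_per_class(y_true, y_pred, num_classes=4):
+    """Per-class Dice between two integer label maps of equal shape."""
+    scores = []
+    for c in range(num_classes):
+        t, p = (y_true == c), (y_pred == c)
+        denom = t.sum() + p.sum()
+        scores.append(1.0 if denom == 0 else 2.0 * np.logical_and(t, p).sum() / denom)
+    return scores   # the mean Dice is simply np.mean(scores)
+```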
+ +### Outputs + +The final trained U-Net model outputs a 4-channel segmentation mask corresponding to the four tissue classes. Each pixel in the mask represents the most probable class as determined by the network’s logits. The model successfully learned to separate the key anatomical structures of the brain while maintaining sharp boundaries between gray and white matter regions. + +At the end of training, all model artefacts were automatically saved in trained_models/oasis_unet/, including: +- best_model.pth — the trained model checkpoint, +- loss.png and dice.png — the training and validation curves shown above. + +The final model achieved a mean validation Dice score of 0.9747, which reflects highly consistent segmentation performance across the dataset. + +### Example Prediction + +Example visualisations generated by predict.py illustrate how the model performs on unseen validation slices. The triplets below show the input MRI slice, the ground truth segmentation, and the predicted segmentation for several representative samples. Dice scores are shown for each class and the overall mean Dice. + +#### Figure 1: Single Validation Example + +![Prediction Single Validation Example](images/prediction_example.png) +``` +Per-class Dice: +C0: 0.9982, C1: 0.9643, C2: 0.9455, C3: 0.9611 +Mean Dice: 0.9673 +``` +This sample demonstrates the model’s strong ability to delineate the ventricles (CSF) and gray/white matter regions, with minimal leakage between classes. + +#### Figure 2: Worse Performance Example + +![Worst](images/worst.png) +``` +Worst — idx 346 | mean Dice: 0.9360 +Per-class: C0: 0.999, C1: 0.861, C2: 0.935, C3: 0.949 +``` +Even in the worst-case sample, the mean Dice remains above 0.93, indicating that the model generalises well and fails gracefully when encountering more complex or noisier slices. + +#### Figure 3: Decent Performance Example + +![Decent](images/decent.png) +``` +Decent — idx 204 | mean Dice: 0.9736 +Per-class: C0: 0.999, C1: 0.929, C2: 0.979, C3: 0.987 +``` +This mid-performing example represents the typical quality of segmentation achieved across most validation slices. + +#### Figure 4: Best Performance Example + +![Best](images/best.png) +``` +Best — idx 913 | mean Dice: 0.9865 +Per-class: C0: 0.999, C1: 0.974, C2: 0.981, C3: 0.992 +``` +The best-performing slice shows nearly perfect agreement between the predicted and ground-truth masks, especially in the cortical and subcortical boundaries where U-Net’s skip connections preserve fine detail. + +### Summary + +The dataset’s high spatial consistency, clear tissue boundaries, and balanced train/validation/test splits enabled the model to generalise effectively. The trained U-Net achieved a mean Dice score of 0.97 on the validation set, with per-class scores of 0.9987 (Background), 0.9536 (CSF), 0.9663 (Gray Matter), and 0.9802 (White Matter). Even the lowest-performing slices maintained Dice scores above 0.93, highlighting the model’s generalisation ability across diverse brain anatomies. These results show that the OASIS dataset provides sufficient quality and diversity for reliable model training and serves as an effective benchmark for evaluating medical image segmentation performance. + +## Conclusion +This project successfully applied a 2D U-Net architecture to the OASIS brain MRI dataset for multi-class tissue segmentation. 
+Despite the dataset’s inherent class imbalance—where background and white matter dominate and CSF regions are comparatively sparse—the model demonstrated strong generalisation and anatomical accuracy. Through cross-entropy optimisation and balanced class weighting, the network achieved a mean Dice score of 0.97, with per-class scores exceeding 0.95 for all major tissues. +These results confirm that the OASIS dataset provides sufficient quality and variation to support robust segmentation learning, and that the U-Net architecture remains highly effective for biomedical image analysis. Future extensions could include adopting a 3D U-Net to exploit volumetric context, incorporating more advanced augmentation strategies to address class imbalance, or integrating a hybrid Dice–cross-entropy loss to further refine performance on smaller regions such as cerebrospinal fluid. + +## References +- Ronneberger, O., Fischer, P., & Brox, T. (2015). *U-Net: Convolutional Networks for Biomedical Image Segmentation.* MICCAI. [arXiv:1505.04597](https://arxiv.org/abs/1505.04597) +- Milletari, F., Navab, N., & Ahmadi, S.-A. (2016). *V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation.* 3DV. [arXiv:1606.04797](https://arxiv.org/abs/1606.04797) +- Sørensen, T. (1948). *A Method of Establishing Groups of Equal Amplitude in Plant Sociology Based on Similarity of Species Content.* *Biologiske Skrifter*, 5(4), 1–34. +- Dice, L. R. (1945). *Measures of the Amount of Ecologic Association Between Species.* *Ecology*, 26(3), 297–302. +- U-Net – *Wikipedia entry.* [https://en.wikipedia.org/wiki/U-Net](https://en.wikipedia.org/wiki/U-Net) + + + diff --git a/recognition/oasis_unet_timothy_nguyen/dataset.py b/recognition/oasis_unet_timothy_nguyen/dataset.py new file mode 100644 index 000000000..bb01095d9 --- /dev/null +++ b/recognition/oasis_unet_timothy_nguyen/dataset.py @@ -0,0 +1,276 @@ +""" +Dataset and preprocessing utilities for OASIS PNG slices. + +Overview +-------- +This module provides a PyTorch Dataset class to load 2D axial slices +from the OASIS brain MRI dataset. It expects a canonical folder structure +with separate 'images/' and 'labels/' subfolders for each data split +(train/val/test). The dataset supports PNG image files for easy use on +Colab and local systems without NIfTI dependencies. 
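+
+Example
+-------
+A minimal usage sketch (paths are illustrative):
+
+    ds = OASIS2DSegmentation(root="./OASIS", split="train")
+    img, mask = ds[0]          # img: (1, H, W) float32, mask: (H, W) int64
+    weights = ds.calculate_class_weights()   # inverse-frequency weights, one per class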
+""" +import os +import re +import glob +from typing import List, Tuple, Optional + +import numpy as np +import torch +from torch.utils.data import Dataset +from PIL import Image + +# --------------------------------------------------------------------- +# Expected canonical layout for the dataset +# +# OASIS/ +# train/ +# images/ +# labels/ +# val/ +# images/ +# labels/ +# test/ +# images/ +# labels/ +# --------------------------------------------------------------------- + +EXPECTED_SPLITS = ("train", "val", "test") + +def _enforce_oasis_layout(root: str) -> None: + """Ensure canonical OASIS/ layout with required split folders.""" + missing = [] + # Verify that all expected split folders exist + for split in EXPECTED_SPLITS: + split_dir = os.path.join(root, split) + if not os.path.isdir(split_dir): + missing.append(f"{split}/") + if missing: + # Construct a detailed error message if any split folders are missing + msg = [ + f"[OASIS layout error] Expected canonical layout under: {os.path.abspath(root)}", + "", + "Required folder structure:", + " OASIS/", + " train/images/", + " train/labels/", + " val/images/", + " val/labels/", + " test/images/", + " test/labels/", + "", + "Missing split folders:", + ] + [f" - {m}" for m in missing] + raise FileNotFoundError("\n".join(msg)) + +# --------------------------------------------------------------------- +# PNG pairing helper functions +# --------------------------------------------------------------------- + +# Regular expressions to detect file naming patterns +# Accept filenames like: +# images: case_367_slice_20.nii.png OR case_367_slice_20.png +# labels: seg_367_slice_20.nii.png OR seg_367_slice_20.png +_RX_NII_PNG = re.compile( + r"^(?Pcase|img|image|seg|label)?_?(?P\d+)_slice_(?P\d+)\.nii\.png$", + re.IGNORECASE, +) +_RX_PNG = re.compile( + r"^(?Pcase|img|image|seg|label)?_?(?P\d+)_slice_(?P\d+)\.png$", + re.IGNORECASE, +) + +def _list_pngs(d: str) -> List[str]: + """Return all .png and .PNG files sorted alphabetically.""" + return sorted(glob.glob(os.path.join(d, "*.png")) + glob.glob(os.path.join(d, "*.PNG"))) + +def _parse_png_key(path: str) -> Tuple[Optional[str], Optional[bool]]: + """Return (key, is_label) from a PNG filename or (None, None) if no match.""" + b = os.path.basename(path) + # Match either .nii.png or plain .png naming pattern + m = _RX_NII_PNG.match(b) or _RX_PNG.match(b) + if not m: + return None, None + pid = m.group("pid") # patient ID + sid = m.group("sid") # slice ID + prefix = (m.group("prefix") or "").lower() + is_label = prefix in {"seg", "label"} # determine if file is a label + return f"{pid}_{sid}", is_label # unique slice key + +def _pair_pngs(images_dir: str, labels_dir: str) -> List[Tuple[str, str]]: + """Pair each image with its corresponding label by filename key.""" + img_paths = _list_pngs(images_dir) + lbl_paths = _list_pngs(labels_dir) + # Index image and label files by their parsed key + by_key_img, bad_img = {}, [] + for p in img_paths: + k, is_label = _parse_png_key(p) + if k is None or is_label: # ignore mislabelled or invalid files + bad_img.append(os.path.basename(p)) + else: + by_key_img[k] = p + by_key_lbl, bad_lbl = {}, [] + for p in lbl_paths: + k, is_label = _parse_png_key(p) + if k is None or not is_label: # ignore files not marked as labels + bad_lbl.append(os.path.basename(p)) + else: + by_key_lbl[k] = p + # Intersect keys present in both images and labels + common = sorted(set(by_key_img).intersection(by_key_lbl)) + pairs = [(by_key_img[k], by_key_lbl[k]) for k in common] + # Raise 
an error if no valid pairs found + if not pairs: + msg = [ + "No paired .png files found.", + f"Images dir: {images_dir} (count={len(img_paths)})", + f"Labels dir: {labels_dir} (count={len(lbl_paths)})", + ] + # Provide details for debugging + img_only = sorted(set(by_key_img) - set(by_key_lbl)) + lbl_only = sorted(set(by_key_lbl) - set(by_key_img)) + if img_only: + msg.append("\nImage keys without matching labels (first 10):") + msg += [f" - {k}" for k in img_only[:10]] + if lbl_only: + msg.append("\nLabel keys without matching images (first 10):") + msg += [f" - {k}" for k in lbl_only[:10]] + if bad_img: + msg.append("\nUnparsable / misplaced files in images/ (first 10):") + msg += [f" - {n}" for n in bad_img[:10]] + if bad_lbl: + msg.append("\nUnparsable / misplaced files in labels/ (first 10):") + msg += [f" - {n}" for n in bad_lbl[:10]] + msg.append( + "\nExpected filename patterns like:\n" + " images: case__slice_.nii.png or case__slice_.png\n" + " labels: seg__slice_.nii.png or seg__slice_.png" + ) + raise FileNotFoundError("\n".join(msg)) + # Warn if some files were not paired or are malformed + leftover_img = sorted(set(by_key_img) - set(common)) + leftover_lbl = sorted(set(by_key_lbl) - set(common)) + if leftover_img or leftover_lbl or bad_img or bad_lbl: + print("[warn] PNG: some files were not paired or were unparsable.") + if leftover_img: + print(f" Unpaired images: {len(leftover_img)} (showing up to 5)") + for k in leftover_img[:5]: + print(" -", os.path.basename(by_key_img[k])) + if leftover_lbl: + print(f" Unpaired labels: {len(leftover_lbl)} (showing up to 5)") + for k in leftover_lbl[:5]: + print(" -", os.path.basename(by_key_lbl[k])) + if bad_img: + print(f" Bad image entries: {len(bad_img)} (showing up to 5)") + for n in bad_img[:5]: + print(" -", n) + if bad_lbl: + print(f" Bad label entries: {len(bad_lbl)} (showing up to 5)") + for n in bad_lbl[:5]: + print(" -", n) + + return pairs + + +# --------------------------------------------------------------------- +# Dataset class +# --------------------------------------------------------------------- +class OASIS2DSegmentation(Dataset): + """ + Canonicalised OASIS 2D dataset (PNG-only). + Expects each split to contain 'images/' and 'labels/' folders. + + Returns + ------- + image : (1, H, W) float32 tensor, z-scored if norm=True + mask : (H, W) int64 tensor with labels in [0..num_classes-1] + """ + + def __init__( + self, + root: str = "./OASIS", + split: str = "train", + num_classes: int = 4, + norm: bool = True, + ): + super().__init__() + assert split in EXPECTED_SPLITS, f"split must be one of {EXPECTED_SPLITS}" + self.root = root + self.split = split + self.num_classes = int(num_classes) + self.norm = bool(norm) + # Validate dataset structure + _enforce_oasis_layout(self.root) + # Build absolute paths to image and label folders + img_dir = os.path.join(self.root, split, "images") + lbl_dir = os.path.join(self.root, split, "labels") + # Ensure required subdirectories exist + if not (os.path.isdir(img_dir) and os.path.isdir(lbl_dir)): + raise FileNotFoundError( + f"Missing required subfolders under {split}/. " + f"Expected 'images/' and 'labels/' inside {os.path.join(self.root, split)}." 
+ ) + # Pair images and labels using helper + self.pairs = _pair_pngs(img_dir, lbl_dir) + if not self.pairs: + raise FileNotFoundError(f"No valid image/label pairs found in split '{split}'.") + + def __len__(self) -> int: + """Return number of (image, label) pairs in this dataset split.""" + return len(self.pairs) + + @staticmethod + def _zscore(arr: np.ndarray) -> np.ndarray: + """Apply z-score normalization to image array.""" + m = float(arr.mean()) + s = float(arr.std()) + if s == 0.0: # avoid divide-by-zero + s = 1.0 + return (arr - m) / s + + def _remap_labels(self, mask: np.ndarray) -> np.ndarray: + """ + Map arbitrary integer labels to compact range [0..num_classes-1]. + Extra or unexpected labels are clipped to the last valid index. + """ + uniq = np.unique(mask) # get all label values in the mask + # Create lookup table mapping each unique label to an index + lut = {int(v): min(i, self.num_classes - 1) for i, v in enumerate(uniq)} + # Replace labels using vectorized mapping + out = np.vectorize(lambda v: lut[int(v)])(mask).astype(np.int64) + return out + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Load and return (image, mask) pair at index `idx`.""" + img_path, lbl_path = self.pairs[idx] + # Load grayscale image as float32 + img = np.asarray(Image.open(img_path).convert("L")).astype(np.float32) + # Optionally normalize to zero-mean and unit variance + if self.norm: + img = self._zscore(img) + img = np.expand_dims(img, axis=0) # add channel dimension → (1, H, W) + # Load segmentation mask + mask = np.asarray(Image.open(lbl_path)) + mask = self._remap_labels(mask) # remap labels to consistent range + # Convert to torch tensors + return torch.from_numpy(img), torch.from_numpy(mask).long() + + def calculate_class_weights(self) -> torch.Tensor: + """ + Compute inverse-frequency class weights for this dataset split. + These weights help balance rare vs. common classes during training. 
+ """ + counts = np.zeros(self.num_classes, dtype=np.int64) + # Count label occurrences across all masks + for _, lbl_path in self.pairs: + m = np.asarray(Image.open(lbl_path)) + m = self._remap_labels(m) + vals, cnt = np.unique(m, return_counts=True) + for v, c in zip(vals, cnt): + if v < self.num_classes: + counts[v] += int(c) + # Avoid zero counts to prevent divide-by-zero + counts = np.maximum(counts, 1) + inv = 1.0 / counts.astype(np.float64) + # Normalise so weights sum to num_classes + w = inv / inv.sum() * self.num_classes + return torch.tensor(w, dtype=torch.float32) diff --git a/recognition/oasis_unet_timothy_nguyen/images/best.png b/recognition/oasis_unet_timothy_nguyen/images/best.png new file mode 100644 index 000000000..bd3b5e205 Binary files /dev/null and b/recognition/oasis_unet_timothy_nguyen/images/best.png differ diff --git a/recognition/oasis_unet_timothy_nguyen/images/decent.png b/recognition/oasis_unet_timothy_nguyen/images/decent.png new file mode 100644 index 000000000..925748286 Binary files /dev/null and b/recognition/oasis_unet_timothy_nguyen/images/decent.png differ diff --git a/recognition/oasis_unet_timothy_nguyen/images/dice.png b/recognition/oasis_unet_timothy_nguyen/images/dice.png new file mode 100644 index 000000000..c6ecbafb2 Binary files /dev/null and b/recognition/oasis_unet_timothy_nguyen/images/dice.png differ diff --git a/recognition/oasis_unet_timothy_nguyen/images/example_oasis_image.png b/recognition/oasis_unet_timothy_nguyen/images/example_oasis_image.png new file mode 100644 index 000000000..1934b0da3 Binary files /dev/null and b/recognition/oasis_unet_timothy_nguyen/images/example_oasis_image.png differ diff --git a/recognition/oasis_unet_timothy_nguyen/images/example_oasis_seg_image.png b/recognition/oasis_unet_timothy_nguyen/images/example_oasis_seg_image.png new file mode 100644 index 000000000..41bc53abd Binary files /dev/null and b/recognition/oasis_unet_timothy_nguyen/images/example_oasis_seg_image.png differ diff --git a/recognition/oasis_unet_timothy_nguyen/images/loss.png b/recognition/oasis_unet_timothy_nguyen/images/loss.png new file mode 100644 index 000000000..578abc9ba Binary files /dev/null and b/recognition/oasis_unet_timothy_nguyen/images/loss.png differ diff --git a/recognition/oasis_unet_timothy_nguyen/images/prediction_example.png b/recognition/oasis_unet_timothy_nguyen/images/prediction_example.png new file mode 100644 index 000000000..048be8961 Binary files /dev/null and b/recognition/oasis_unet_timothy_nguyen/images/prediction_example.png differ diff --git a/recognition/oasis_unet_timothy_nguyen/images/u-net-architecture.png b/recognition/oasis_unet_timothy_nguyen/images/u-net-architecture.png new file mode 100644 index 000000000..312c59f07 Binary files /dev/null and b/recognition/oasis_unet_timothy_nguyen/images/u-net-architecture.png differ diff --git a/recognition/oasis_unet_timothy_nguyen/images/worst.png b/recognition/oasis_unet_timothy_nguyen/images/worst.png new file mode 100644 index 000000000..c621c5cb1 Binary files /dev/null and b/recognition/oasis_unet_timothy_nguyen/images/worst.png differ diff --git a/recognition/oasis_unet_timothy_nguyen/modules.py b/recognition/oasis_unet_timothy_nguyen/modules.py new file mode 100644 index 000000000..d7dfd9262 --- /dev/null +++ b/recognition/oasis_unet_timothy_nguyen/modules.py @@ -0,0 +1,166 @@ +""" +Minimal U-Net model components for 2D medical image segmentation. 
+ +Overview +-------- +This module defines the convolutional blocks and U-Net architecture used for +segmenting 2D slices of MRI data from the OASIS dataset. The model here is +intentionally simple and fully self-contained, making it suitable for +quick experimentation on Colab or local systems. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# --------------------------------------------------------------------- +# Basic building blocks for the U-Net architecture +# --------------------------------------------------------------------- +class DoubleConv(nn.Module): + """ + A common U-Net block consisting of two convolutional layers, + each followed by batch normalization and ReLU activation. + + Parameters + ---------- + in_ch : int + Number of input channels. + out_ch : int + Number of output channels. + """ + + def __init__(self, in_ch, out_ch): + super().__init__() + self.net = nn.Sequential( + # First convolution: reduces aliasing and extracts features + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), + # Second convolution: refines features for better localization + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + """Apply two conv-batchnorm-relu layers.""" + return self.net(x) + +# --------------------------------------------------------------------- +# Downsampling and Upsampling helpers +# --------------------------------------------------------------------- +class Down(nn.Module): + """Downscaling with maxpool followed by DoubleConv.""" + + def __init__(self, in_ch, out_ch): + super().__init__() + # Max pooling halves spatial dimensions, DoubleConv increases feature depth + self.net = nn.Sequential( + nn.MaxPool2d(2), + DoubleConv(in_ch, out_ch) + ) + + def forward(self, x): + """Apply maxpool → double convolution.""" + return self.net(x) + + +class Up(nn.Module): + """Upscaling followed by DoubleConv.""" + + def __init__(self, in_ch, out_ch, bilinear=True): + super().__init__() + # Choose between bilinear interpolation or transposed convolution + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.conv = DoubleConv(in_ch, out_ch) + else: + self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, kernel_size=2, stride=2) + self.conv = DoubleConv(in_ch, out_ch) + + def forward(self, x1, x2): + """ + Forward pass for the upsampling block. + + Parameters + ---------- + x1 : torch.Tensor + Decoder feature map (after upsampling). + x2 : torch.Tensor + Corresponding encoder feature map for skip connection. 
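+
+        Returns
+        -------
+        torch.Tensor
+            Fused feature map: `x1` is upsampled, padded to match `x2`,
+            concatenated with it along the channel dimension, and refined
+            by a DoubleConv block.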
+ """ + x1 = self.up(x1) # upscale by factor of 2 + # Handle possible size mismatches due to rounding in pooling + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + # Pad if upsampled map is smaller (center crop alignment) + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2]) + # Concatenate encoder and decoder features (skip connection) + x = torch.cat([x2, x1], dim=1) + return self.conv(x) # refine fused features + + +class OutConv(nn.Module): + """Final 1×1 convolution to map feature channels to class logits.""" + + def __init__(self, in_ch, out_ch): + super().__init__() + self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1) + + def forward(self, x): + """Return per-pixel class logits (no activation).""" + return self.conv(x) + +# --------------------------------------------------------------------- +# U-Net architecture +# --------------------------------------------------------------------- +class UNet(nn.Module): + """ + Standard 2D U-Net architecture for semantic segmentation. + + Parameters + ---------- + in_channels : int + Number of channels in input image (e.g., 1 for grayscale MRI slices). + out_channels : int + Number of target segmentation classes. + bilinear : bool + Whether to use bilinear upsampling (True) or transposed convolutions (False). + """ + + def __init__(self, in_channels=1, out_channels=4, bilinear=True): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.bilinear = bilinear + # Encoder: progressively reduce spatial resolution while increasing depth + self.inc = DoubleConv(in_channels, 64) # input conv block + self.down1 = Down(64, 128) # 1st downsample + self.down2 = Down(128, 256) # 2nd downsample + self.down3 = Down(256, 512) # 3rd downsample + # If bilinear upsampling, reduce intermediate channels by half + factor = 2 if bilinear else 1 + self.down4 = Down(512, 1024 // factor) + # Decoder: progressively upsample while fusing encoder features + self.up1 = Up(1024, 512 // factor, bilinear) + self.up2 = Up(512, 256 // factor, bilinear) + self.up3 = Up(256, 128 // factor, bilinear) + self.up4 = Up(128, 64, bilinear) + # Final output layer: maps to segmentation logits + self.outc = OutConv(64, out_channels) + + def forward(self, x): + """Forward pass through encoder, bottleneck, and decoder.""" + x1 = self.inc(x) # encode level 1 + x2 = self.down1(x1) # encode level 2 + x3 = self.down2(x2) # encode level 3 + x4 = self.down3(x3) # encode level 4 + x5 = self.down4(x4) # bottleneck + # Decode with skip connections (mirrors encoder) + x = self.up1(x5, x4) # combine bottleneck + encoder-4 + x = self.up2(x, x3) # combine decoder + encoder-3 + x = self.up3(x, x2) # combine decoder + encoder-2 + x = self.up4(x, x1) # combine decoder + encoder-1 + logits = self.outc(x) # produce final logits + return logits # no softmax here (applied in loss/inference) diff --git a/recognition/oasis_unet_timothy_nguyen/predict.py b/recognition/oasis_unet_timothy_nguyen/predict.py new file mode 100644 index 000000000..83451e197 --- /dev/null +++ b/recognition/oasis_unet_timothy_nguyen/predict.py @@ -0,0 +1,265 @@ +""" +Prediction and visualisation for trained 2D U-Net on OASIS PNG slices. + +Overview +-------- +- Loads dataset (OASIS2DSegmentation) in PNG format. +- Rebuilds model from modules.py and loads the saved checkpoint. +- Runs inference on a selected image slice OR the entire split (--scan mode). +- Saves a side-by-side figure showing input, ground truth, and prediction. 
+- Optionally identifies best, worst, and median Dice predictions when scanning. + +This script helps verify that the model produces reasonable segmentations after training. +""" + +import os +import sys +import argparse +from pathlib import Path +import numpy as np +import torch +import torch.nn as nn +import matplotlib.pyplot as plt +# Local imports +from dataset import OASIS2DSegmentation +import modules + +def dice_per_class_np(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int) -> list[float]: + """Compute per-class Dice for a single predicted mask (numpy arrays).""" + dices = [] + for c in range(num_classes): + t = (y_true == c) + p = (y_pred == c) + inter = np.logical_and(t, p).sum() + denom = t.sum() + p.sum() + dices.append(1.0 if denom == 0 else (2.0 * inter) / denom) + return dices + +def build_model(num_classes: int, device: torch.device) -> nn.Module: + """ + Builds a UNet or UNet2D model from modules.py and sends it to the target device. + + Parameters + ---------- + num_classes : int + Number of output segmentation classes. + device : torch.device + Device to move the model to (CPU or CUDA). + + Returns + ------- + nn.Module + The instantiated model placed on the given device. + """ + if hasattr(modules, "UNet"): + try: + m = modules.UNet(in_channels=1, out_channels=num_classes) + return m.to(device) + except TypeError: + m = modules.UNet().to(device) + return m + if hasattr(modules, "UNet2D"): + try: + m = modules.UNet2D(in_channels=1, out_channels=num_classes) + return m.to(device) + except TypeError: + m = modules.UNet2D().to(device) + return m + raise RuntimeError("No compatible model found in modules.py (expected UNet or UNet2D).") + +def parse_args(): + """ + Parse command-line arguments for inference and visualization. + + Returns + ------- + argparse.Namespace + Parsed arguments for input/output paths and options. + """ + p = argparse.ArgumentParser(description="Predict/visualise using trained UNet on OASIS") + p.add_argument("--root", type=str, default="./OASIS", help="Path to OASIS/ canonical tree") + p.add_argument("--num-classes", type=int, default=4) + p.add_argument("--ckpt", type=str, default="trained_models/oasis_unet/best_model.pth") + p.add_argument("--out", type=str, default="outputs/prediction_example.png") + p.add_argument("--split", type=str, default="val", choices=["train", "val", "test"]) + p.add_argument("--index", type=int, default=0, help="Dataset index to visualise") + p.add_argument("--scan", action="store_true", help="Scan full split to find best/worst/median Dice examples") + return p.parse_args() + +def load_checkpoint(model: nn.Module, ckpt_path: Path): + """ + Loads a saved model checkpoint into the given model. + + Parameters + ---------- + model : nn.Module + Instantiated U-Net model. + ckpt_path : Path + Path to the .pth checkpoint file. + + Returns + ------- + dict + The checkpoint dictionary (contains model_state, optimizer_state, etc.). + """ + ckpt = torch.load(ckpt_path, map_location="cpu") # Load checkpoint to CPU by default + state = ckpt.get("model_state", ckpt) # Extract model state dict if wrapped + model.load_state_dict(state, strict=False) # Load weights into the model + return ckpt + +@torch.no_grad() +def predict_one(model: nn.Module, img: torch.Tensor) -> torch.Tensor: + """ + Perform forward pass on a single image tensor. + + Parameters + ---------- + model : nn.Module + Trained U-Net model. + img : torch.Tensor + Input tensor of shape (1,1,H,W) or (1,H,W). 
+ + Returns + ------- + torch.Tensor + Predicted segmentation mask of shape (H,W) with integer class labels. + """ + if img.ndim == 3: + img = img.unsqueeze(0) # Ensure batch dimension -> (1,1,H,W) + logits = model(img) # Forward pass -> raw logits (1,C,H,W) + pred = logits.argmax(dim=1)[0] # Convert to predicted label map (H,W) + return pred.cpu() # Return prediction on CPU for visualization + +def render_triplet(img: np.ndarray, gt: np.ndarray, pred: np.ndarray, save_path: Path, title: str = ""): + """ + Render a triplet of input image, ground-truth, and prediction. + + Parameters + ---------- + img : np.ndarray + Input grayscale image (H,W). + gt : np.ndarray + Ground truth label mask (H,W). + pred : np.ndarray + Predicted label mask (H,W). + save_path : Path + Path where the resulting visualization will be saved. + title : str, optional + Optional title for figure (used in scan mode). + """ + save_path.parent.mkdir(parents=True, exist_ok=True) + plt.figure(figsize=(12, 4)) + plt.subplot(1, 3, 1) + plt.imshow(img, cmap="gray") + plt.title("Input") + plt.axis("off") + plt.subplot(1, 3, 2) + plt.imshow(gt, interpolation="nearest") + plt.title("Ground Truth") + plt.axis("off") + plt.subplot(1, 3, 3) + plt.imshow(pred, interpolation="nearest") + plt.title("Prediction") + plt.axis("off") + if title: + plt.suptitle(title) + plt.tight_layout(rect=[0, 0, 1, 0.96] if title else None) + plt.savefig(save_path) + plt.close() + +def main(): + """ + Entry point for inference and visualization. + + Steps: + ------ + 1. Parse command-line arguments. + 2. Load the OASIS dataset (PNG backend). + 3. Rebuild the model and load checkpoint weights. + 4. Perform prediction on one example (default). + 5. Or scan full split (--scan) to export best, worst, and median examples. 
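+
+    Example invocations (from the project README):
+        python predict.py --root ./OASIS --ckpt trained_models/oasis_unet/best_model.pth --split val --index 0
+        python predict.py --root ./OASIS --ckpt trained_models/oasis_unet/best_model.pth --split val --scan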
+ """ + args = parse_args() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Initialize dataset (checks canonical OASIS layout) + try: + ds = OASIS2DSegmentation(root=args.root, split=args.split, + num_classes=args.num_classes, norm=True) + except FileNotFoundError as e: + print(str(e)) + sys.exit(2) + if len(ds) == 0: + print(f"No data found in split '{args.split}' under {args.root}.") + sys.exit(2) + # Build and load model checkpoint + model = build_model(args.num_classes, device) + ckpt_path = Path(args.ckpt) + if not ckpt_path.exists(): + print(f"Checkpoint not found at: {ckpt_path}") + sys.exit(2) + load_checkpoint(model, ckpt_path) + model.eval() + # -------------------------------------- + # Mode 1: Single example (default) + # -------------------------------------- + if not args.scan: + idx = max(0, min(args.index, len(ds) - 1)) + img_t, gt_t = ds[idx] + img_in = img_t.unsqueeze(0).to(device) + pred_t = predict_one(model, img_in) + + img_np = img_t.squeeze(0).cpu().numpy() + gt_np = gt_t.cpu().numpy() + pred_np = pred_t.cpu().numpy() + + render_triplet(img_np, gt_np, pred_np, Path(args.out)) + per_class = dice_per_class_np(gt_np, pred_np, num_classes=args.num_classes) + print("Per-class Dice (single example):", ", ".join(f"C{c}: {d:.4f}" for c, d in enumerate(per_class))) + print("Mean Dice (single example):", f"{np.mean(per_class):.4f}") + print(f"Saved visualisation to: {args.out}") + return + # -------------------------------------- + # Mode 2: Scan entire split (--scan) + # -------------------------------------- + outdir = Path("outputs/gallery") + outdir.mkdir(parents=True, exist_ok=True) + print(f"Scanning {len(ds)} samples from split '{args.split}'...") + scores = [] + cache = {} + with torch.no_grad(): + for idx in range(len(ds)): + img_t, gt_t = ds[idx] + img_in = img_t.unsqueeze(0).to(device) + pred_t = predict_one(model, img_in) + img_np = img_t.squeeze(0).cpu().numpy() + gt_np = gt_t.cpu().numpy() + pred_np = pred_t.cpu().numpy() + per_class = dice_per_class_np(gt_np, pred_np, args.num_classes) + mean_dice = float(np.mean(per_class)) + scores.append((mean_dice, idx)) + cache[idx] = {"img": img_np, "gt": gt_np, "pred": pred_np, "per_class": per_class} + # Sort and pick best/worst/median + scores.sort(key=lambda x: x[0]) + worst_score, worst_idx = scores[0] + best_score, best_idx = scores[-1] + med_target = np.median([s for s, _ in scores]) + decent_idx = min(scores, key=lambda x: abs(x[0] - med_target))[1] + decent_score = [s for s, i in scores if i == decent_idx][0] + # Save figures + for name, idx, score in [ + ("best", best_idx, best_score), + ("worst", worst_idx, worst_score), + ("decent", decent_idx, decent_score), + ]: + item = cache[idx] + title = f"{name.title()} — idx {idx} | mean Dice {score:.4f} | per-class: {', '.join(f'{x:.3f}' for x in item['per_class'])}" + render_triplet(item["img"], item["gt"], item["pred"], outdir / f"{name}.png", title) + print(f"Saved examples to {outdir}/") + print(f" Best (idx={best_idx}) mean Dice={best_score:.4f}") + print(f" Worst (idx={worst_idx}) mean Dice={worst_score:.4f}") + print(f" Decent (idx={decent_idx}) mean Dice={decent_score:.4f}") + print("Per-class labels: C0=Background, C1=CSF, C2=Gray Matter, C3=White Matter") + + +if __name__ == "__main__": + main() diff --git a/recognition/oasis_unet_timothy_nguyen/train.py b/recognition/oasis_unet_timothy_nguyen/train.py new file mode 100644 index 000000000..4f9835b6b --- /dev/null +++ b/recognition/oasis_unet_timothy_nguyen/train.py @@ -0,0 +1,338 
@@ + +""" +Training and validation pipeline for the 2D U-Net model on OASIS PNG slices. + +Overview +-------- +This script ties together the dataset loader, model, optimizer, and metrics. +It supports both UNet and UNet2D models defined in modules.py, computes +per-class Dice metrics, and saves loss/metric curves and checkpoints. The +pipeline is fully runnable on CPU or GPU. +""" + +import os +import sys +import argparse +from pathlib import Path +from typing import Tuple +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import matplotlib.pyplot as plt +# Local imports for dataset and model definitions +from dataset import OASIS2DSegmentation +import modules + +# ----------------------------- +# Utilities +# ----------------------------- +def build_model(num_classes: int, device: torch.device) -> nn.Module: + """ + Build a U-Net (or UNet2D) from modules.py and place it on the selected device. + + Parameters + ---------- + num_classes : int + Number of output segmentation classes. + device : torch.device + Target device (CPU or CUDA). + + Returns + ------- + nn.Module + Instantiated model moved to the given device. + + Notes + ----- + - The function tries both UNet and UNet2D constructors. + - If a constructor fails due to missing args, it retries with defaults. + """ + # Prefer the UNet class if present + if hasattr(modules, "UNet"): + try: + m = modules.UNet(in_channels=1, out_channels=num_classes) + return m.to(device) + except TypeError: + # Fallback: if constructor signature differs, call without keyword args + m = modules.UNet().to(device) + return m + # Otherwise, look for UNet2D variant + if hasattr(modules, "UNet2D"): + try: + m = modules.UNet2D(in_channels=1, out_channels=num_classes) + return m.to(device) + except TypeError: + m = modules.UNet2D().to(device) + return m + # Raise an error if neither model is found + raise RuntimeError("No compatible model found in modules.py (expected UNet or UNet2D).") + +@torch.no_grad() +def dice_per_class(pred_logits: torch.Tensor, target: torch.Tensor, num_classes: int) -> torch.Tensor: + """ + Compute mean per-class Dice score for a batch. + + Parameters + ---------- + pred_logits : torch.Tensor + Raw model outputs with shape (B, C, H, W). + target : torch.Tensor + Ground truth integer masks with shape (B, H, W). + num_classes : int + Number of segmentation classes. + + Returns + ------- + torch.Tensor + Vector of Dice scores (C,) averaged over the batch. + """ + pred = pred_logits.argmax(dim=1) # convert logits to discrete predictions + dices = [] + eps = 1e-6 # smoothing to avoid divide-by-zero + for c in range(num_classes): + pred_c = (pred == c).float() + targ_c = (target == c).float() + inter = (pred_c * targ_c).sum(dim=(1, 2)) # intersection per image + denom = pred_c.sum(dim=(1, 2)) + targ_c.sum(dim=(1, 2)) + eps + d = (2.0 * inter + eps) / denom # Dice coefficient formula + dices.append(d) + dices = torch.stack(dices, dim=1) # (B,C) + return dices.mean(dim=0) # return average per-class Dice + +def save_curves(save_dir: Path, train_losses, val_losses, train_dices, val_dices): + """ + Save training/validation loss and Dice curves as PNG plots. + + Parameters + ---------- + save_dir : Path + Directory to store curve images. + train_losses, val_losses : list[float] + Lists of loss values per epoch. + train_dices, val_dices : list[float] + Lists of mean Dice per epoch. 
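+
+    Notes
+    -----
+    Writes two files into save_dir: loss.png and dice.png.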
+ """ + save_dir.mkdir(parents=True, exist_ok=True) + # ---- Loss curve ---- + plt.figure() + plt.plot(train_losses, label="train") + plt.plot(val_losses, label="val") + plt.xlabel("Epoch") + plt.ylabel("CrossEntropy Loss") + plt.title("Training/Validation Loss") + plt.legend() + plt.tight_layout() + plt.savefig(save_dir / "loss.png") + plt.close() + # ---- Dice curve ---- + plt.figure() + plt.plot([float(x) for x in train_dices], label="train (mean Dice)") + plt.plot([float(x) for x in val_dices], label="val (mean Dice)") + plt.xlabel("Epoch") + plt.ylabel("Mean Dice") + plt.title("Training/Validation Mean Dice") + plt.legend() + plt.tight_layout() + plt.savefig(save_dir / "dice.png") + plt.close() + +# ----------------------------- +# Training / Validation loops +# ----------------------------- +def train_one_epoch( + model: nn.Module, + optimizer: torch.optim.Optimizer, + criterion: nn.Module, + loader: DataLoader, + device: torch.device, + num_classes: int, +) -> Tuple[float, float]: + """ + Perform a single training epoch. + + Returns + ------- + Tuple[float, float] + (mean_loss, mean_dice) + """ + model.train() + total_loss = 0.0 + total_dice = 0.0 + n_batches = 0 + # iterate over mini-batches + for imgs, masks in loader: + imgs = imgs.to(device, non_blocking=True) + masks = masks.to(device, non_blocking=True) + optimizer.zero_grad(set_to_none=True) + logits = model(imgs) + loss = criterion(logits, masks) + loss.backward() + optimizer.step() + # compute dice for this batch + with torch.no_grad(): + per_class = dice_per_class(logits, masks, num_classes) + mean_dice = float(per_class.mean().item()) + total_loss += float(loss.item()) + total_dice += mean_dice + n_batches += 1 + return total_loss / max(n_batches, 1), total_dice / max(n_batches, 1) + +@torch.no_grad() +def dice_intersections_unions( + logits: torch.Tensor, target: torch.Tensor, num_classes: int +): + """ + Return per-class intersection and (|P| + |T|) sums for a batch. + Accumulate these across batches to compute dataset-level per-class Dice. + """ + # logits -> predicted labels + pred = torch.argmax(logits, dim=1) # (N,H,W) + + # one-hot: (N,C,H,W) + n, h, w = pred.shape + pred_1h = torch.zeros((n, num_classes, h, w), device=pred.device, dtype=torch.float32) + tgt_1h = torch.zeros_like(pred_1h) + pred_1h.scatter_(1, pred.unsqueeze(1), 1.0) + tgt_1h.scatter_(1, target.unsqueeze(1), 1.0) + + inter = (pred_1h * tgt_1h).sum(dim=(0, 2, 3)) # (C,) + sums = pred_1h.sum(dim=(0, 2, 3)) + tgt_1h.sum(dim=(0, 2, 3)) # (C,) + return inter, sums + +@torch.no_grad() +def validate( + model: nn.Module, + criterion: nn.Module, + loader: DataLoader, + device: torch.device, + num_classes: int, +) -> Tuple[float, float]: + """ + Evaluate the model on the validation split. + Prints dataset-level per-class Dice (aggregated correctly across batches). 
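+    Per-class intersections and sums are accumulated over the whole split
+    before the ratio is taken, so slices that lack a class entirely do not
+    skew the per-class averages the way simple per-batch averaging would.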
+ + Returns + ------- + Tuple[float, float] + (mean_loss, mean_dice) + """ + model.eval() + total_loss = 0.0 + n_batches = 0 + + # Accumulate numerators/denominators for per-class Dice across the dataset + inter_total = torch.zeros(num_classes, device=device, dtype=torch.float64) + sums_total = torch.zeros(num_classes, device=device, dtype=torch.float64) + + for imgs, masks in loader: + imgs = imgs.to(device, non_blocking=True) + masks = masks.to(device, non_blocking=True) + + logits = model(imgs) + loss = criterion(logits, masks) + + total_loss += float(loss.item()) + n_batches += 1 + + inter, sums = dice_intersections_unions(logits, masks, num_classes) + inter_total += inter.to(torch.float64) + sums_total += sums.to(torch.float64) + + eps = 1e-6 + per_class_dice = (2.0 * inter_total + eps) / (sums_total + eps) # (C,) + mean_dice = float(per_class_dice.mean().item()) + avg_loss = total_loss / max(n_batches, 1) + + # Pretty print per-class Dice for this epoch + pcs = per_class_dice.detach().cpu().tolist() + print(" Val per-class Dice:", ", ".join(f"C{c}: {d:.4f}" for c, d in enumerate(pcs))) + + return avg_loss, mean_dice + +# ----------------------------- +# Main entry point +# ----------------------------- +def parse_args(): + """Parse CLI arguments for training configuration.""" + p = argparse.ArgumentParser(description="Train 2D UNet on canonical OASIS layout") + p.add_argument("--root", type=str, default="./OASIS", help="Path to OASIS dataset root") + p.add_argument("--epochs", type=int, default=12, help="Number of epochs to train") + p.add_argument("--batch-size", type=int, default=4, help="Batch size for training/validation") + p.add_argument("--lr", type=float, default=1e-3, help="Learning rate for Adam optimizer") + p.add_argument("--num-classes", type=int, default=4, help="Number of segmentation classes") + p.add_argument("--save-dir", type=str, default="./trained_models/oasis_unet", help="Directory for outputs") + p.add_argument("--num-workers", type=int, default=2, help="Number of DataLoader workers") + p.add_argument("--no-class-weights", action="store_true", help="Disable inverse-frequency class weighting") + return p.parse_args() + +def main(): + """Main training control flow.""" + args = parse_args() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + save_dir = Path(args.save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + # Initialize datasets; verify canonical folder layout + try: + train_ds = OASIS2DSegmentation(root=args.root, split="train", + num_classes=args.num_classes, norm=True) + val_ds = OASIS2DSegmentation(root=args.root, split="val", + num_classes=args.num_classes, norm=True) + except FileNotFoundError as e: + print(str(e)) + sys.exit(2) + # DataLoaders for train/val + train_loader = DataLoader( + train_ds, batch_size=args.batch_size, shuffle=True, + num_workers=args.num_workers, pin_memory=True + ) + val_loader = DataLoader( + val_ds, batch_size=args.batch_size, shuffle=False, + num_workers=args.num_workers, pin_memory=True + ) + # Initialize model + model = build_model(args.num_classes, device) + # Optionally compute class weights to balance loss + if args.no_class_weights: + class_weights = None + else: + try: + class_weights = train_ds.calculate_class_weights().to(device) + except Exception: + class_weights = None + criterion = nn.CrossEntropyLoss(weight=class_weights) + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + # Initialize trackers + best_val_dice = -1.0 + train_losses, val_losses = [], [] + train_dices, 
val_dices = [], [] + # ---- Training loop ---- + for epoch in range(1, args.epochs + 1): + tr_loss, tr_dice = train_one_epoch(model, optimizer, criterion, train_loader, device, args.num_classes) + va_loss, va_dice = validate(model, criterion, val_loader, device, args.num_classes) + train_losses.append(tr_loss) + val_losses.append(va_loss) + train_dices.append(tr_dice) + val_dices.append(va_dice) + print(f"[Epoch {epoch:03d}] " + f"Train Loss: {tr_loss:.4f} | Val Loss: {va_loss:.4f} | " + f"Train Dice: {tr_dice:.4f} | Val Dice: {va_dice:.4f}") + # Save checkpoint when validation Dice improves + if va_dice > best_val_dice: + best_val_dice = va_dice + ckpt = { + "epoch": epoch, + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + "val_dice": best_val_dice, + "num_classes": args.num_classes, + } + torch.save(ckpt, save_dir / "best_model.pth") + # Save learning curves + save_curves(save_dir, train_losses, val_losses, train_dices, val_dices) + print(f"Training complete. Best Val Dice: {best_val_dice:.4f}. " + f"Artifacts saved to: {save_dir}") + + +if __name__ == "__main__": + main()