This repository was created as a homework assignment during my M.Sc. studies in Data Science.
The dataset consists of images of different textures that have to be classified. However, instead of a single label per image, a mask must be predicted (semantic segmentation). For the first network, only convolutional layers may be used. For the second network, the choice of architecture is free; I decided to use a U-Net, which was developed at the University of Freiburg, Germany.
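To illustrate how a segmentation output differs from a classification output, here is a minimal, dependency-free sketch: instead of one class-score vector per image, the network produces one score vector per pixel, and the mask is obtained by taking the argmax over the class axis at every pixel. The function name scores_to_mask and the (class, row, column) layout are assumptions for illustration only; deep learning frameworks usually do this with a single argmax over the channel dimension.

```python
# Sketch: turn per-pixel class scores of shape (C, H, W) into an (H, W) mask
# by picking the highest-scoring class at every pixel (hypothetical helper).
from typing import List

ScoreMap = List[List[List[float]]]  # indexed as scores[class][row][col]
Mask = List[List[int]]              # indexed as mask[row][col]


def scores_to_mask(scores: ScoreMap) -> Mask:
    """Return the index of the highest-scoring class for every pixel."""
    num_classes = len(scores)
    height = len(scores[0])
    width = len(scores[0][0])
    mask: Mask = []
    for row in range(height):
        mask_row = []
        for col in range(width):
            # Compare the scores of all classes at this pixel.
            best_class = max(range(num_classes), key=lambda c: scores[c][row][col])
            mask_row.append(best_class)
        mask.append(mask_row)
    return mask


if __name__ == "__main__":
    # Toy example: 3 classes on a 2x2 image.
    toy_scores = [
        [[0.9, 0.1], [0.2, 0.3]],   # class 0
        [[0.05, 0.8], [0.1, 0.2]],  # class 1
        [[0.05, 0.1], [0.7, 0.5]],  # class 2
    ]
    print(scores_to_mask(toy_scores))  # [[0, 1], [2, 2]]
```

A classifier would reduce the whole image to a single label; here every pixel keeps its own label, which is why the target for this task is a mask.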
Create the conda environment and activate it:
conda env create -f environment.yml
conda activate describable-textures-dataset
Download the images:
python dtd_loader_color_patches.py
Arguments:
--tiled: Set to True if you want to use the larger colored dataset
Both datasets (tiled and non-tiled) can be downloaded. The tiled setting in config.yaml determines which of them is used for training and testing.
The following plot shows the class distribution of the masks:
This distribution is only taken into account when the BCE-with-logits loss is used.
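One common way to use such a distribution with a BCE-with-logits loss is to weight each class by its rarity, for example through the pos_weight argument of PyTorch's BCEWithLogitsLoss. Whether this repository computes its weights exactly this way is an assumption; the hypothetical helper below only sketches the idea in plain Python, using weight = (number of pixels not labelled c) / (number of pixels labelled c).

```python
# Hypothetical helper: derive per-class positive weights from the mask
# distribution, e.g. for the pos_weight argument of a BCE-with-logits loss.
# Rare classes get a larger weight, frequent classes a smaller one.
from collections import Counter
from typing import Iterable, List


def class_pos_weights(masks: Iterable[List[List[int]]], num_classes: int) -> List[float]:
    counts: Counter = Counter()
    total = 0
    for mask in masks:
        for row in mask:
            counts.update(row)  # count how often each class label occurs
            total += len(row)
    weights = []
    for c in range(num_classes):
        positives = counts[c]
        # Classes that never occur fall back to a neutral weight of 1.0.
        weights.append((total - positives) / positives if positives else 1.0)
    return weights


if __name__ == "__main__":
    # One 2x2 mask: class 0 covers 3 pixels, class 1 covers 1, class 2 is absent.
    print(class_pos_weights([[[0, 0], [0, 1]]], num_classes=3))
```

The fallback of 1.0 for absent classes is a design choice that avoids division by zero; with 47 classes and small patches, some classes may well be missing from a given subset.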
General settings are specified in the file config.yaml:
device: either cuda or cpu
num_classes: number of classes; must be 47 for the DTD dataset
max_num_epoch: maximum number of epochs (or the fixed number of epochs when early stopping is disabled)
loss: the loss function; one of cross-entropy, dice or bce-with-logits
tiled: True if the tiled dataset shall be used
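The constraints listed for config.yaml can be checked before training starts. The sketch below is hypothetical: the dictionary stands in for the parsed YAML file (reading the file itself would need a YAML library such as PyYAML), the function name validate_config is invented, and the example values are placeholders rather than the repository's defaults.

```python
# Hypothetical validator for the documented config.yaml settings.
from typing import Any, Dict, List

ALLOWED_DEVICES = {"cuda", "cpu"}
ALLOWED_LOSSES = {"cross-entropy", "dice", "bce-with-logits"}
DTD_NUM_CLASSES = 47


def validate_config(cfg: Dict[str, Any]) -> List[str]:
    """Return a list of problems; an empty list means the config looks valid."""
    problems = []
    if cfg.get("device") not in ALLOWED_DEVICES:
        problems.append(f"device must be one of {sorted(ALLOWED_DEVICES)}")
    if cfg.get("num_classes") != DTD_NUM_CLASSES:
        problems.append(f"num_classes must be {DTD_NUM_CLASSES} for the DTD dataset")
    max_epochs = cfg.get("max_num_epoch")
    # bool is a subclass of int in Python, so reject it explicitly.
    if not isinstance(max_epochs, int) or isinstance(max_epochs, bool) or max_epochs < 1:
        problems.append("max_num_epoch must be a positive integer")
    if cfg.get("loss") not in ALLOWED_LOSSES:
        problems.append(f"loss must be one of {sorted(ALLOWED_LOSSES)}")
    if not isinstance(cfg.get("tiled"), bool):
        problems.append("tiled must be True or False")
    return problems


example_cfg = {
    "device": "cuda",
    "num_classes": 47,
    "max_num_epoch": 50,  # placeholder value
    "loss": "dice",
    "tiled": False,
}
print(validate_config(example_cfg))  # []
```

Returning a list of problems instead of raising on the first one reports every misconfigured key at once.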
Train a model:
python train.py [OPTIONS]
Arguments:
--learning_rate: the learning rate
--batch_size: the batch size
--model_name: name of the model; one of 'simple_fcn', 'simple_u-net' or 'pretrained_u-net'
--wandb: Set to True if you want to use wandb.ai; by default, TensorBoard is used
--early_stopping: Set to True if you want to use early stopping
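The following is a hypothetical reconstruction of the documented train.py options with Python's argparse; it is not the repository's actual parser, and all default values are assumptions. It illustrates one pitfall behind options that are "set to True": argparse's type=bool would treat any non-empty string, including "False", as True, so the sketch parses the string explicitly with an invented str2bool helper.

```python
# Hypothetical sketch of the train.py command-line interface (defaults assumed).
import argparse

MODEL_NAMES = ["simple_fcn", "simple_u-net", "pretrained_u-net"]


def str2bool(value: str) -> bool:
    """Parse 'True'/'False' (and a few common spellings) into a bool."""
    lowered = value.strip().lower()
    if lowered in {"true", "1", "yes"}:
        return True
    if lowered in {"false", "0", "no"}:
        return False
    raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Train a segmentation model on DTD.")
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--model_name", choices=MODEL_NAMES, default="simple_fcn")
    parser.add_argument("--wandb", type=str2bool, default=False,
                        help="log to wandb.ai instead of TensorBoard")
    parser.add_argument("--early_stopping", type=str2bool, default=False)
    return parser


# Parse an explicit argument list instead of sys.argv for demonstration.
demo = build_parser().parse_args(["--model_name", "simple_u-net", "--early_stopping", "True"])
print(demo.model_name, demo.early_stopping)  # simple_u-net True
```

Restricting --model_name with choices makes argparse reject misspelled model names before any training code runs.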
With wandb, it is also possible to run hyperparameter sweeps. First, set model_name in the file sweep.yaml, then execute:
wandb sweep sweep.yaml
wandb agent your-sweep-id
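For orientation, a sweep definition for wandb has top-level keys such as program, method, metric and parameters. The sketch below builds a hypothetical one in Python that mirrors the train.py options; the metric name "val_loss" and every value range are assumptions and would have to match what train.py actually logs and accepts. It prints JSON, which YAML parsers generally accept as well.

```python
# Hypothetical sweep definition mirroring the train.py options.
import json

sweep_config = {
    "program": "train.py",
    "method": "random",  # wandb also supports "grid" and "bayes"
    "metric": {"name": "val_loss", "goal": "minimize"},  # assumed metric name
    "parameters": {
        "model_name": {"value": "simple_u-net"},  # fixed model, as set in sweep.yaml
        "learning_rate": {"min": 1e-4, "max": 1e-2},  # assumed range
        "batch_size": {"values": [8, 16, 32]},  # assumed candidates
    },
}

print(json.dumps(sweep_config, indent=2))
```

A single value fixes a parameter for the whole sweep, a values list is sampled from, and min/max defines a range; the agent then passes the sampled values to the program as command-line options.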
Evaluate a model:
python evaluate.py [OPTIONS]
Arguments:
--model_name: name of the model; one of 'simple_fcn', 'simple_u-net' or 'pretrained_u-net'
--accuracy: Set to True if the top-1 accuracy on the test set shall be calculated
--plot: Set to True if some predictions shall be plotted
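For a segmentation task, a natural reading of "top-1 accuracy" is the fraction of pixels whose highest-scoring class matches the ground-truth class. That interpretation, and the hypothetical function pixel_top1_accuracy below, are assumptions; evaluate.py may define the metric differently (for example, averaged per image).

```python
# Sketch of a pixel-wise top-1 accuracy over a whole test set (assumed metric).
# Predictions are already-argmaxed masks with the same shape as the targets.
from typing import Iterable, List, Tuple

Mask = List[List[int]]


def pixel_top1_accuracy(pairs: Iterable[Tuple[Mask, Mask]]) -> float:
    correct = 0
    total = 0
    for predicted, target in pairs:
        for pred_row, target_row in zip(predicted, target):
            for p, t in zip(pred_row, target_row):
                correct += int(p == t)  # count matching pixels
                total += 1
    if total == 0:
        raise ValueError("no pixels to evaluate")
    return correct / total


pred = [[0, 1], [2, 2]]
gold = [[0, 1], [2, 0]]
print(pixel_top1_accuracy([(pred, gold)]))  # 0.75
```

Counting pixels over the whole test set weights every pixel equally; averaging per-image accuracies instead would weight every image equally, which differs when image sizes vary.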
