diff --git a/.gitignore b/.gitignore index fd20fddf8..1670801e7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ -*.pyc +*.pyc \ No newline at end of file diff --git a/recognition/TRANSFORMER_43909856/README.md b/recognition/TRANSFORMER_43909856/README.md new file mode 100644 index 000000000..dbbbe9590 --- /dev/null +++ b/recognition/TRANSFORMER_43909856/README.md @@ -0,0 +1,300 @@ +# ViT Transformer for image classification of the ADNI dataset + + +## Description +This code can be used to train, validate, and test a ViT (Visual Transformer) model +on the ADNI brain dataset. The ViT model is a binary classifier that takes 2D MRI slice +images as its inputs. It attempts to determine if the patient in the MRI image slice has +either Alzheimer's disease, or a healthy (Cognitive Normal) brain. + + +## Dependencies +This code is written in Python 3.11.5. + +The following libraries/modules are also used: +- pytorch 2.1.0 +- pytorch-cuda 11.8 +- torchvision 0.16.0 +- torchdata 0.7.0 +- matplotlib 3.7.2 +- scikit-learn 1.3.0 +- einops 0.7.0 + +It is strongly recommended that these packages are installed within a new conda +(Anaconda/Miniconda) environment, and that the code is run within this environment. +These libraries can then be installed into the conda environment +using these lines in the terminal: + +``` +conda install pytorch torchvision torchaudio torchdata pytorch-cuda=11.8 -c pytorch -c nvidia + +conda install matplotlib + +conda install scikit-learn + +pip install einops +`````` + +### Reproducibility +Model training was completed on the UQ Rangpur HPC server, using the vgpu40 +node with the following hardware specifications: + +- 8x vCPU cores (AMD Zen 2) +- 128 GB RAM +- NVIDIA A100 40GB vGPU + +For more information, see [UQ EAIT compute infrastructure.](https://student.eait.uq.edu.au/infrastructure/compute/) + +The model was saved after training, and inference was later completed on a local device. + + +## How a ViT works +A ViT (Visual Transformer) is a variation on Transformer models, designed specifically for image processing. + +In a Transformer, input data and target values are given an embedding with a positional encoding. This stores the sequential characteristics/nature of the inputs and targets. In the case of a ViT used for classification, an input image is split into tokens, made up of separated, smaller 'patches' of the original image. Patches generate a positional embedding, based on the position of the patch within the entire image. The target values would be categorical labels, used to classify the input images. + +These are converted into Keys, Queries, and Values, which are fed into multi-head attention modules, followed by small, linear feed-forward networks. An encoded representation of the original inputs is produced, which is fed into other network components (as well as target values). Finally, a softmax layer is applied to the outputs, giving the probabilities of each class being predicted. + +Here is a diagram illustrating the main components of a ViT: +![(Dosovitskiy et al., 2021)](plots/Dosovitskiy_etal_2021_ViT_diagram.PNG) +*(Dosovitskiy et al., 2021)* + +The overall architecture of this ViT was based on the ViT-S/16 model, as mentioned in the 2022 IEEE conference paper "Scaling Vision Transformers" (Zhai et al.). + +### S/16 model details +The S/16 model variant of the ViT was chosen, as it was believed that it would provide a "good tradeoff" between performance and computational speed/efficiency (Beyer et al., 2022). It was believed that this would be the most appropriate, where hardware resources were partly limited. + +Model specs (as specified in Zhai et al, 2022): +- **Patch size:** 16x16 +- **Number of encoder blocks (depth):** 12 +- **Dimensionality of patch embeddings and self-attention modules (width):** 384 +- **Number of attention heads:** 6 +- **Dimension of hidden MLP Linear layers:** 1536 + +The created model differs slightly from the original S/16 ViT model, including modifications as added by Beyere et al. (2022). These include 2-dimensional sinusoids used for positional embeddings, and the use of global average/mean pooling instead of using a class token. + + +## Examples +### Model inputs - ADNI dataset +The input images are 2D MRI slices of a patient's brain, taken from a Alzheimer's +Disease Neuroimaging Initiative (ADNI) dataset. +The images used with this model have been preprocessed. +For more specific details, please see the [ADNI dataset website.](https://adni.loni.usc.edu/) + +The model takes 224x224 colour/RGB images as inputs, where each RGB channel contains +intensity values for each pixel in the range [0, 255]. + +ADNI images are cropped and resized from 256x240 to 224x224. + +![Sample data from the ADNI dataset](plots/ADNI_sample_data.png) + +Each image is also assigned a class label representing whether or not Alzheimer's +disease was observed in the patient. These labels are "AD" (Alzheimer's Detected) +and "NC" (Cognitive Normal). Within the model, these labels are transformed into +numerical categorical values (0 for AD, 1 for NC). + +### Model outputs +A binary classification label is generated by the model, based on the input image. +A 0 value is returned if the AD (Alzheimer's Detected) is predicted, and a 1 value +is returned if predicting the NC (Cognitive Normal) class. + + +## Preprocessing + +### Image resizing +Input images of size 256x240 were resized, then center cropped to a square +resolution of 224x224, which was used as the image input size for the model. +As the images were approximately centred (and all positioned similarly), +processing them in this manner preserved the position of brains in the MRI slices, +whilst resizing them to dimensions that could be evenly downsampled multiple times +by the model network. + +### Normalising the data +The files were loaded as RGB images. These contained 3 channels, in which each +channel represents an intensity value in the range [0, 255]. These values were +standardised such that the means and standard deviations were both changed to +0.5. This placed intensity values for each channel within the range [-1, 1]. + +### Data augmentation +No data augmentation was explicitly applied to the input data. However, all available +MRI image slices for each patient were used in the data set (20 slices per patient). As these slices are +distinct, but all map the same patient's brain (for the same classification), +these slices may act similarly to 'augmented' data. This may result in the model being +more invariant to changes in the different MRI slices provided to it. In some contexts +where additional unseen data (potentially from different datasets) is tested, +the model could potentially be more generalisable as a result. + +### Train, validation, and test splits +The data was split into a training, validation, and test set. +Training set data was used to train the model, with binary cross-entropy loss used +to evaluate its performance throughout the training process. + +During training, the model performance was evaluated on the validation set at the final step of +each epoch. The relationship between training set loss and validation set loss was observed, to note the +points of training in which the model was overfitting or underfitting. The most optimal +length of time for training the model was manually selected, and the model was re-trained with a +different number of epochs. As such, validation set performance was used to perform tuning/selection of +a hyperparameter (the number of epochs). + +The test set was used to evaluate the model performance on unseen data, quantified +by the accuracy metric. + +#### Number of data points +The ADNI dataset contains 1526 patients (30520 MRI image slices). +The test set was composed of data points sampled from the 'test' directory of the ADNI dataset. +This set contained MRI image slices from 223 AD patients (4460 images) and +227 NC patients (4540 images), giving 450 patients (9000 images total). The test set +comprises of roughly 29-30% of the entire dataset. + +Training and validation sets contained points sampled from the 'train' directory of this dataset, +which contains around 70-71% of the data. +80% of 'train' dir data (860 patients, 17200 images) was used in the training set, +with 416 AD patients (8320 images) and 444 NC patients (8880 images). +20% (216 patients, 4320 images) of this data was used in the validation set - +this contained 104 AD patients (2080 images) and 112 NC patients (2240 images). + +#### Justification +The validation set was chosen to be approximately half the size of the test set. +It was considered more beneficial to quantify the model's performance on +a larger selection of unseen data (in the test set), than to utilise more of this +data for the purpose of hyperparameter tuning or training. When more data is moved to the +test set, the distribution of test data more accurately represents the +characteristics of the entire dataset. +The size of the training set was also not decreased in favour of the other sets used, +to allow for the model to train on an appropriate quantity of varying data points. + +The split between the training and validation sets was stratified +(attempting to roughly preserve the class proportions within each split set). +A stratified split can result in more effective training/useful testing. In saying +this, I don't believe that this has made a significant difference to this model +(as the class proportions are almost approximately equal). + +#### Preventing data leakage +The training and validation set data appears to be independent from the test set data, +with no overlapping patient MRI images within both of the 'train' and 'test' data +directories. + +To prevent data leakage within the train and validation sets (split from the 'train' +directory data), the MRI slices were grouped by patient, then the patients were +shuffled and split between each set. After the split, data points (each MRI image +slice) were shuffled within each set. This process ensured that the data was +appropriately shuffled, whilst preventing images from one patient being allocated +to both the train and validation set. + + +## Results: +### Training for 90 epochs: +To perform hyperparameter tuning for the number of epochs, a large number of epochs (90) was initially chosen. The results from the training and validation sets were saved and plotted: + +![90 epochs - training vs. validation loss](plots/ViT_90epochs_loss.png) +Noticeably, the training and validation set loss appears to reach an optimal point between 40 and 60 epochs, then begins to fluctuate significantly at some points after this, despite reaching what appear to be the optimal loss values. + +Validation set accuracy (per-batch) approaches 100% around this training period: + +![90 epochs - validation accuracy](plots/ViT_90epochs_validation_accuracy.png) + + These values could appear optimal on paper. However, the test set performance for the model trained for 90 epochs was ~60.32%, indicating that the model is overfitting significantly to the training set, and generalising poorly to the test set. + + ![90 epochs - confusion matrix on test set](plots/ViT_90epochs_test_confusion_matrix.png) + +The confusion matrix for the model (based on the test set predictions) shows that +the model makes a significant number of incorrect predictions. Notably, the model predicts high quantities of False Positives (where AD/Alzheimer's Detected is the positive class), and a signficiantly smaller amount of False Negatives (where NC/Cognitive Normal is the negative class). + + To avoid overfitting the model, a model with a shorter training duration of 40 epochs was created and tested. + +### Training for 40 epochs: + +![40 epochs - training vs. validation loss](plots/ViT_40epochs_loss.png) + +The training and validation loss, and the validation set accuracy indicate that the model reaches the most optimal loss values at the end of training, and this model likely suffers from less overfitting as a result. + +![40 epochs - validation accuracy](plots/ViT_40epochs_validation_accuracy.png) + +However, the test set for this model performed only marginally better than the 90 epoch model, with a test accuracy of ~60.98%. This indicates that the model still poorly generalises to unseen data, and this can be seen in the model's confusion matrix: + +![40 epochs - confusion matrix on test set](plots/ViT_40epochs_test_confusion_matrix.png) + +Conversely to the 90 epoch model, many False Negatives and significantly fewer False Positives are predicted. +In some contexts, it may be preferable for a model to predict False Positives over False Negatives. +For medical circumstances, mistakenly predicting the existence of the condition in a healthy person may be preferred over not predicting the condition in a sick person, as high False Negative rates can prevent sick individuals from +receiving timely preventative treatment. Because of this, the 90 epochs model could be a more optimal choice of the two in some contexts. Even though test set accuracy is higher in the 40 epochs model, the difference in test set performance between the two is relatively minor. + +A model trained on 30 epochs was tested, to see if the generalisability of the model could be increased further. + +### Training for 30 epochs: + +![30 epochs - training vs. validation loss](plots/ViT_loss.png) + +![30 epochs - validation accuracy](plots/ViT_validation_accuracy.png) + +This model had an accuracy of ~60.37% on the test set - the performance of this model appeared to be somewhere inbetween the 90 epoch and 40 epoch models, although the difference in performance between these models is incredibly minor (less than 1% difference in test accuracy between any of these models). + +![40 epochs - confusion matrix on test set](plots/ViT_test_confusion_matrix.png) + +Similarly to the 40 epochs model, a much higher number of the misclassified images were False Negatives, and much less were false positives. + +As the test accuracy for this model is lower than the 40 epochs model, and this model also predicts high quantites of False Negatives, it would likely not be considered the most optimal model, unless the computational efficiency and faster training time of the model were of significant concern. + + +### Summary +Depending on the context of the model's use, either the 90 epoch or 40 epoch model may be more contextually appropriate. If arbirarily higher test set performance is desired above anything, and faster/more efficient model training is also preferred, then the 40 epoch model would be the most ideal candidate of these three. If the model was required to predict more False Postives than False Negatives, then the 90 epoch model may be the more ideal choice for this. There may also be a more appropriate length of training time (between 40 and 90 epochs) that results in higher performing models than these three. Some of these models may also be the more optimal choice for saving training time and preferencing certain misclassifications over others. + + +## Possible improvements? +The model's performance is not ideal in its current state; the test set results illustrate +that the model is prone to overfitting, and does not generalise well to unseen data. + +The model could be improved on or further experimented with, in multiple different ways: + +### Data augmentation: +The majority of the images featured the brains positioned roughly in the middle, +with an approximate orientation. However, it was noted that the ADNI dataset would +occasionally contain MRI images that were rotated or positioned differently to others. +As the model would be less likely to see this data during the training process, +it may be less invariant to minor rotations and positional changes. To improve the +model performance and make the model more generalisable, it could be regularised by augmenting the training set data. Images could be duplicated within the training set, then augmented with minor rotations and positional translations. Beyer et al. (2022) noted +that ViT models experienced improved performance on baseline image processing datasets +with the use of similar data augmentation techniques. + +### Changing RGB images to greyscale: +The input data is MRI images, in which each pixel of the image is a "greyscale value +that ranges from 0 (pure black) to 255 (pure white)" (Gerber & Peterson, 2008). +Because of this, loading the input data as RGB images results in superfluous +information (the intensity values in each channel are identical). + +Converting the image to a single-channel, greyscale format would likely improve +the computational efficiency of the model during training and inference. +It's also possible that the model could achieve more ideal performance by training on data points with less superfluous data. + +### More hyperparameter tuning: +Whilst the number of epochs/length of training was examined with a validation set, +many other model hyperparameters could additionally be tuned. For example, other +model configurations listed in the Zhai et al. (2022) paper could also be tested on the ADNI dataset, and their performance compared with the S/16 model. +This paper also suggests using either a constant or a reciprocal square-root +learning rate scheduler, to prevent the learning rate from decaying too quickly during training. +This scheduler would include a "warmup" and a "cooldown" period. +Currently, the PyTorch OneCycleLR scheduler is used to "warmup", then "cooldown" the +learning rate in a non-linear fashion. It's possible that a constant or inverse +square-root scheduler may encourage the model to train more effectively, as the learning rate at some points will decay at a slower rate. + +### Other regularisation techniques: +As the model appeared to often suffer from overfitting, some other model-based regularisation techniques such as network dropout could improve the network's generalisability. As well as this, the hyperparameter tuning of the training time (number of epochs) could be utilised to perform early stopping on the model training process, once the loss values have converged to within a given +threshold. + + +## References +- Alzheimer's Disease Neuroimaging Initiative. (2017). *ADNI | Alzheimer's Disease Neuroimaging Initiative.* https://adni.loni.usc.edu/ + +- Beyer, L., Zhai, X., Kolesnikov, A. (2022). *Better plain ViT baselines for ImageNet-1k.* https://arxiv.org/abs/2205.01580 + +- Doshi, K. (2021, January 17). *Transformers explained visually (Part 3): Multi-head attention, deep dive.* https://towardsdatascience.com/ +transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853 + +- Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., Houlsby, N. (2021). An image is worth 16x16 words: Transformers for image recognition at scale. *International Conference on Learning Representations.* https://arxiv.org/pdf/2010.11929.pdf + +- Gerber, A. J., & Peterson, B. S. (2008). What is an image? *Journal of the American Academy of Child and Adolescent Psychiatry*, 47(3), 245–248. https://doi.org/10.1097/CHI.0b013e318161e509 + +- Raschka, S. (2022, June 12). *Taking Datasets, DataLoaders, and PyTorch's new DataPipes for a spin.* https://sebastianraschka.com/blog/2022/datapipes.html#DataPipesforDatasetsWithImagesandCSVs + +- Zhai, X. Kolesnikov, A., Houlsby, N., Beyer, L. (2022). Scaling Vision Transformers. *Institute of Electrical and Electronics Engineers.* https://ieeexplore.ieee.org/document/9880094 + diff --git a/recognition/TRANSFORMER_43909856/dataset.py b/recognition/TRANSFORMER_43909856/dataset.py new file mode 100644 index 000000000..45c3c9384 --- /dev/null +++ b/recognition/TRANSFORMER_43909856/dataset.py @@ -0,0 +1,422 @@ +import os +import os.path as osp +from typing import Dict, List, Tuple +import torch +import torch.nn as nn +from torchvision import transforms +from torch.utils.data import Dataset, DataLoader, random_split, default_collate +from torchvision.datasets import ImageFolder +import matplotlib.pyplot as plt +from torchvision.utils import make_grid +import numpy as np +from torchdata.datapipes.iter import BucketBatcher, FileLister, Mapper, RandomSplitter, UnBatcher +from PIL import Image +from torch.utils.data.backward_compatibility import worker_init_fn +from torchvision.utils import make_grid +import matplotlib.pyplot as plt + +""" +Contains the data loader for loading and preprocessing the ADNI dataset. + +This resource in particular was very useful for creating custom components of the dataset +loading. Some of the code written in this file was based on the general pipeline +followed in the information on this website: +https://sebastianraschka.com/blog/2022/datapipes.html#DataPipesforDatasetsWithImagesandCSVs +""" + + +#### Model hyperparameters: #### +BATCH_SIZE = 32 + + +#### Dataset parameters: #### +# The number of MRI image slices per patient in the dataset +N_IMGS_PER_PATIENT = 20 +# Dimensions to resize the original 256x240 images to (IMG_SIZE x IMG_SIZE) +IMG_SIZE = 224 + + +#### Input processing transforms: #### +# Create basic transforms for the images (using these for now, will need to add other transforms later) +BASIC_TF = transforms.Compose([transforms.ToTensor()]) +''' +Create transforms that resize the image, then crop it to create a 224x224 image. +The transforms will also normalise the RGB intensity values for the data to per-channel +means and standard deviations of 0.5 - this places intensity values in the range +[-1, 1]. +''' +TRAIN_TF = transforms.Compose([ + transforms.Resize(IMG_SIZE), + transforms.CenterCrop(IMG_SIZE), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) +TEST_TF = transforms.Compose([ + transforms.Resize(IMG_SIZE), + transforms.CenterCrop(IMG_SIZE), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) +VAL_TF = transforms.Compose([ + transforms.Resize(IMG_SIZE), + transforms.CenterCrop(IMG_SIZE), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) +# Should validation and test transforms be different? I don't see why they should be + +# TODO could try some data augmentation on these transforms? +# TODO try changing from RGB images to greyscale, compare model performance + + +#### File paths: #### +DATASET_PATH = osp.join("recognition", "TRANSFORMER_43909856", "dataset", "AD_NC") + + +''' +Need to split training set data into a training and validation set. +Need to avoid data leakage -> +Need to group patient MRI image slices by the patient number, and group these +into 'bins' for each patient. +Once this is done, we can then add all slices for each patient to either the +train or the validation set. +Should try do a stratified split of AD and NC class images. + +Total of 743 AD patients and 783 NC patients in train and test sets. + +Test: +AD: +- 4460 images total - 20 MRI slices per patient - 223 patients total? +- MRI slice numbers labelled differently for some patients +- 6 or 7 char patient ID numbers? +- format of image name: 'patientID_MRIslice.jpeg' +NC: +- 4540 images total - 20 MRI slices per patient - 227 patients total? + +Train: +AD: +- 10400 images total - 20 MRI slices per patient - 520 patients total? +NC: +- 11120 images total - 20 MRI slices per patient - 556 patients total? + +Splitting train set into a train and validation set (80/20 stratified split): +- Train: 416 AD patients and ~444 NC patients (860 total) +- Validation: 104 AD patients and ~112 NC patients (216 total) +''' + + +""" +Loads the ADNI dataset test images from the given local directory/path. In cases +where only a train and test set are created, this method will also be used to +load the training set. +Applies the specified transforms to this set. + +It is assumed that the ADNI dataset images are organised in this directory +structure, relative to the project: + - dataset_path/: + - 'test/' + - 'AD/ + - 'NC/' + - 'train/' + - 'AD/' + - 'NC/' +By default, dataset_path is set to: './recognition/TRANSFORMER_43909856/dataset/AD_NC'. +The PyTorch ImageFolder class automatically assigns class labels for each image +based on the subfolders in 'train' and 'test'. An image in an 'AD' dir is +assigned a class label of 'AD' (0) (Alzheimer's Detected), and an image in an 'NC' +dir is assigned a class label of 'NC' (1) (Normal Cognition). + +Params: + dataset_path (str): the directory containing the ADNI dataset images, structured + by the image classifications + tf (torch transform): the transform to be applied to the data + batch_size (int): the number of input images to be added to each DataLoader batch + dataset (str): "train" or "test" set + +Returns: + The given set's data +""" +def load_ADNI_data(dataset_path=DATASET_PATH, tf=TEST_TF, batch_size=BATCH_SIZE, + dataset="test"): + # Load the ADNI data + data = ImageFolder(root=osp.join(dataset_path, dataset), transform=tf) + + return data + + +""" +Sort a selection of images from an input bucket based on their filename (in +lexicographic order), so that images belonging to the same patient are grouped +together in batches. + +Implementation of this method assumes that all image filenames within the given +bucket are within the same directory locations, so that the image files can +be correctly sorted into lexicographic order. By sorting them by image file name, +the images are automatically sorted and grouped by patient ID (as the patient +ID is the first component of the image file names). + +Params: + bucket (torch object): a given 'bucket'/collection of images, with their + filenames included +Returns: + A sorted version of the bucket - entries are sorted by image filename, in + lexicographic order +""" +def patient_sort(bucket): + return sorted(bucket) + + +""" +Opens the PIL image specified by the given filename. Returns the opened PIL + +Params: + file_data (tuple(str, str)): a filename for the PIL image to be opened, and + label for the given data point associated with + that file ("AD" or "NC") +Returns: + Tuple containing the opened PIL image, and the label for the given + data point associated with that image ("AD" or "NC") +""" +def open_image(file_data): + filename, class_name = file_data + return Image.open(filename).convert("RGB"), class_name + + +""" +Determines the class label to be assigned to a given file, based on the +contents of its filename. Returns an assignment of the class label to the filename. + +Implementation assumes that the subdirs of the train dir separates datapoints of +different classes into different dirs (AD classes are in the "AD" subdir, and +NC classes are in the "NC" subdir). +Because of this, the method assumes that there must be one or more occurrences +of the particular class name ("AD" or "NC") in the given filename. + +Params: + filename (str): the file name of the given input image +Returns: + Tuple containing the given filename, and the class for that image + file ("AD" - 0 or "NC" - 1) + +Method throws an exception if the class label can't be determined (there are +no "AD" or "NC" substrings in the filename, indicating that the +"AD" and "NC" subdirs don't exist). +""" +def add_class_labels(filename): + split = filename.split("AD_NC") + if split[-1].find("AD") != -1: + # File is in the "AD" subdir + class_name = 0 + elif split[-1].find("NC") != -1: + # File is in the "NC" subdir + class_name = 1 + else: + # If the class can't be determined, throw an exception + return Exception(f"The class label for {split[-1]} is unknown.") + return filename, class_name + + +""" +Apply a transform to images in the training set. + +Params: + image_data (tuple(PIL image, str)): contains the opened PIL image, and + the class label for that image +Returns: + The transformed input image, and the class label for that image + (not transformed) +""" +def apply_train_tf(image_data, train_tf=TRAIN_TF): + image, class_name = image_data + return train_tf(image), class_name + + +""" +Apply a transform to images in the validation set. + +Params: + image_data (tuple(PIL image, str)): contains the opened PIL image, and + the class label for that image +Returns: + The transformed input image, and the class label for that image + (not transformed) +""" +def apply_val_tf(image_data, val_tf=VAL_TF): + image, class_name = image_data + return val_tf(image), class_name + + +""" +Loads the ADNI dataset train images from the given local directory/path. +Depending on the provided train_size param, a validation set may also be +generated from data in the 'train' subdir, using a stratified split. +To prevent data leakage, the train and validation set are created using a +patient-based split. All MRI image slices for each patient are grouped +together (per patient) - each patient is then shuffled and split into +training and validation sets. After the split is performed, the patient MRI +slices are then 'ungrouped', and data within the sets is then shuffled for each +individual image. +The method also applies the specified transforms to the train and/or validation set. + +Implementation of this method assumes that there are exactly 20 MRI image slices +per patient within the dataset. Additionally, it is assumed that there is no +data leakage between the pre-determined train and test sets (there is no patient +data within the training set, where that same patient has the same data or other +data of their own within the test set). + +Params: + dataset_path (str): the directory containing the ADNI dataset images, structured + by the image classifications + train_tf (torch transform): the transform to be applied to the training set data + val_tf (torch transform): the transform to be applied to the validation set data + batch_size (int): the number of input images to be added to each DataLoader batch + train_size (float): the size of data points that will be added to the + train set. If < 1, the remaining size will be + added to a validation set + (val_size = 1 - train_size). + Implementation assumes that this value is in the + range (0, 1]. + imgs_per_patient (int): the number of MRI slice images per patient which are + present in the dataset + +Returns: + Tuple with 3 values: + The train set data, and the number of training points in the + train set. If train_size < 1, the validation + set data is also returned; otherwise, a value of None is returned. +""" +def load_ADNI_data_per_patient(dataset_path=DATASET_PATH, train_tf=TRAIN_TF, + val_tf=VAL_TF, batch_size=BATCH_SIZE, train_size=0.8, + imgs_per_patient=N_IMGS_PER_PATIENT): + if train_size >= 1: + ''' + If train_size >= 1, create only a training set. + Load the data in the same manner used to load the ADNI test set. + ''' + train_images = load_ADNI_data(dataset_path=dataset_path, tf=train_tf, + batch_size=batch_size, dataset="train") + # Set the validation set DataLoader to none (no validation set used) + return train_images, len(list(train_images)), None + + ''' + Create a training and validation set: + Get all jpeg files in the train set subdirectories, then label the data + (with the AD or NC classes). + ''' + AD_files = FileLister(root=osp.join(dataset_path, "train", "AD"), + masks="*.jpeg", recursive=False).map( + add_class_labels) + NC_files = FileLister(root=osp.join(dataset_path, "train", "NC"), + masks="*.jpeg", recursive=False).map( + add_class_labels) + + ''' + Add the data into distinct batches, grouped by patient ID + (the batches contain the 20 MRI images per patient in the dataset). + Performs a buffer shuffle, which shuffles the batches corresponding to each + patient within the entire bucket (but doesn't shuffle the 20 images + within each patient's batch). + ''' + AD_batch = AD_files.bucketbatch(use_in_batch_shuffle=False, + batch_size=N_IMGS_PER_PATIENT, sort_key=patient_sort) + NC_batch = NC_files.bucketbatch(use_in_batch_shuffle=False, + batch_size=N_IMGS_PER_PATIENT, sort_key=patient_sort) + + ''' + Perform a stratified split of AD and NC images by the train_size argument. + Note that the data has previously been shuffled by patient ID, within each + of the two classes. + ''' + val_size = 1 - train_size + AD_train, AD_val = AD_batch.random_split(weights={"train": train_size, + "validation": val_size}, + total_length=len(list(AD_batch)), + seed=2) + NC_train, NC_val = NC_batch.random_split(weights={"train": train_size, + "validation": val_size}, + total_length=len(list(NC_batch)), + seed=3) + + ''' + Combine the AD and NC class splits into combined train and validation sets. + Once combined, unbatch the data (so that data images are no longer batched + by patient). + Then, shuffle all data images so that the entirety of a patient's + data is not placed together (in one particular section of the dataset). + ''' + train_data = AD_train.concat(NC_train).unbatch().shuffle() + val_data = AD_val.concat(NC_val).unbatch().shuffle() + # Get the number of training set data points: + n_train_datapoints = len(list(train_data)) + + + ''' + Apply a sharding filter to the data after shuffling has taken place. + Open the PIL images from the given dataset filenames. + Once opened, apply the specified train and validation transforms to the images. + ''' + train_images = train_data.sharding_filter().map(open_image).map(apply_train_tf) + val_images = val_data.sharding_filter().map(open_image).map(apply_val_tf) + + return train_images, n_train_datapoints, val_images + + +""" +Plots a 4x4 grid of sample images from a specified split data set (train, +validation, or test) within the ADNI dataset. + +Params: + loader (torch DataLoader): a DataLoader for the given train, test, or validation + set, which contains randomly shuffled MRI image slices + show_plot (bool): show the plot in a popup window if True; otherwise, don't + show the plot + save_plot (bool): save the plot as a PNG file to the directory "plots" if + True; otherwise, don't save the plot +""" +def plot_data_sample(loader, show_plot=False, save_plot=False): + ### Set-up GPU device #### + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if not torch.cuda.is_available(): + print("Warning: CUDA not found. Using CPU") + else: + print(torch.cuda.get_device_name(0)) + + # Get the size of the set: + #print(f"Data points: {len(loader.dataset)}") + + # Plot a selection of images from a single batch of the dataset + sample_data = next(iter(loader)) + # Create a grid of 4x4 images + plt.figure(figsize=(4,4)) + plt.axis("off") + # Add a title + plt.title("Sample of ADNI dataset MRI images") + # Plot the first 16 images in the batch + plt.imshow(np.transpose(make_grid(sample_data[0].to(device)[:16], padding=2, + normalize=True).cpu(),(1, 2, 0))) + + if save_plot: + # Create an output folder for the plot, if one doesn't already exist + directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'plots') + if not os.path.exists(directory): + os.makedirs(directory) + # Save the plot in the "plots" directory + plt.savefig(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'plots', + "ADNI_sample_data.png"), dpi=600) + + if show_plot: + # Show the plot + plt.show() + + +""" +Main method - make sure to run any methods in this file within here. +Adding this so that multiprocessing runs appropriately/correctly +on Windows devices. +""" +def main(): + pass + +if __name__ == '__main__': + main() + diff --git a/recognition/TRANSFORMER_43909856/dataset/README.md b/recognition/TRANSFORMER_43909856/dataset/README.md new file mode 100644 index 000000000..01780015f --- /dev/null +++ b/recognition/TRANSFORMER_43909856/dataset/README.md @@ -0,0 +1,22 @@ +# dataset: + +The ADNI dataset should be placed in this directory. + +To ensure that the model and dataset loading works correctly, the dataset should +be added such that it conforms to the following directory structure: + +``` + - dataset/: + - 'test/' + - 'AD/ + - ... + - 'NC/' + - ... + - 'train/' + - 'AD/' + - ... + - 'NC/' + - ... +``` + +'...' represents the ADNI image JPEG files. \ No newline at end of file diff --git a/recognition/TRANSFORMER_43909856/models/README.md b/recognition/TRANSFORMER_43909856/models/README.md new file mode 100644 index 000000000..a661a534c --- /dev/null +++ b/recognition/TRANSFORMER_43909856/models/README.md @@ -0,0 +1,9 @@ +# models + +After model training has completed, the trained model will be saved to this directory location. + +This location will also contain saved metrics for the train, validation, and test set performance. + +These metrics include the train and validation loss throughout the training process, +the validation set accuracy, and the predictions and observed classes for the validation +and test set. \ No newline at end of file diff --git a/recognition/TRANSFORMER_43909856/modules.py b/recognition/TRANSFORMER_43909856/modules.py new file mode 100644 index 000000000..b815d7aca --- /dev/null +++ b/recognition/TRANSFORMER_43909856/modules.py @@ -0,0 +1,411 @@ +import os +import os.path as osp +import torch +import torch.nn as nn +from einops import rearrange +from einops.layers.torch import Rearrange + + +""" +This file contains all of the components required to create a 2D image recognition +transformer (ViT) used for a binary classification problem. +ViT for ImageNet: https://arxiv.org/abs/2205.01580 +https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9880094 + +- ViT has a set of tokens + 1 class token. In this model, average pooling of the +model will be used instead of a class token +- Uses an average pooling layer at the end of all convolution layers. +Generates a 1 dimensional string used for outputting classification of images +- Image transformers don't use a masked multi-head attention component, +as they don't need to be auto-regressive (look at both information from both the +past and the future of the current position in the input). This transformer is +bi-directional. + + +Possible hyperparams: +- How big the MLP depth is +- No. of attention heads +- Width of the network +- How many attention layers + +This model is using the S/16 configuration from the ViT paper above. + + +Einops: +- Use to perform dot products on particular indices of passed tensors +using Einstein summation +- Use to transform from 1D to 2D, patchifying, etc. + + +A rough diagram of required components: + +Main sub-component #1: +-> +-> Multi-Head Attention -> Add & Layer Norm -> Feed Forward --> Add & Layer Norm -> +-> ^ | ^ +---------------------------------| |---------| + + +Main sub-component #2: +-> +-> Multi-Head Attention ------> Add & Layer Norm -> Feed Forward --> Add & Layer Norm -> +-> ^ ^ | ^ | ^ + | | |---------| |---------| + |-| + | + +Full ViT: +Inputs -> Input Embedding -> Positional Encoding -> Main sub-component #1 -> +Outputs (shifted right) -> Output Embedding -> Positional Encoding --------> Main sub-component #2 -> Linear -> Softmax -> Output probabilities + +""" + + +""" +Creates the multi-head attention modules used within the ViT network. +The component of the network taking inputs will need N multi-head attention modules. +The component of the network taking outputs (also connected to the previous component) +will also need N multi-head attention modules. + +Also includes the components for calculating the scaled dot product attention +from the input Keys, Queries, and Values. + +Scaled Dot-Product Attention: +Q -> + MatMul -> Scale -> Mask (optional) -> SoftMax -> Matmul -> +K -> ^ +V ----------------------------------------------------| + +Multi-Head Attention: +V -> Linear -> +K -> Linear -> Scaled Dot-Product Attention -> Concat -> Linear -> +Q -> Linear -> + +A masked multi-head attention could be created by adding a mask layer to the +scaled dot product attention component, but a mask layer will not be used for +this model. +""" +class Attention(nn.Module): + + """ + Create/initialise a multi-head attention module, using self-attention. + + Params: + dimensions (int): dimensions/size of the input data + n_heads (int): the number of heads added to each multi-head attention component + head_dimensions (int): the dimensions/size of each head added to the attention. + """ + def __init__(self, dimensions, n_heads=8, head_dimensions=64): + super().__init__() + + self.n_heads = n_heads + # Normalise the matrix, using the square root of the size of the head + self.scale = head_dimensions ** -0.5 + # All operations will be normalised (layer norm for 1D representations, similar to batch norm) + self.layer_norm = nn.LayerNorm(dimensions) + # Used for performing network concatenations + inner_dimensions = head_dimensions * n_heads + + # Softmax layer for each scaled dot product attention (applied before matmul out) + self.attend = nn.Softmax(dim=-1) + # Converts every token to a Query, Key, or Value + self.to_qkv = nn.Linear(dimensions, inner_dimensions * 3, bias=False) + # After concatenating the scaled dot product attention, concatenate this into a linear layer + self.to_out = nn.Linear(inner_dimensions, dimensions, bias=False) + + + """ + Perform one forward pass step (forward propagation) to train the attention + module. Create Keys, Queries, and Values from the input data. + + Params: + x: 1D representation of input data (usually in the form of tokens) + Returns: + Computed result of attention module (after Linear flattening layer is applied) + """ + def forward(self, x): + # Normalise the input data + x = self.layer_norm(x) + ''' + Convert the input into Keys, Queries, and Values. + If using cross-attention, QKV would be split between x and y, where y + is another set of tokens. + ''' + qkv = self.to_qkv(x).chunk(3, dim=-1) + ''' + Convert the KQV tensors into groups, then split them + across the attention heads. + b - dimensions/size of each batch + n - number of batches + h - number of heads + d - dimensions/size of each head + ''' + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.n_heads), qkv) + + # Get correlations - perform matrix multiplication between Q and K. Scale result to [0, 1] + q_k_correlations = torch.mul(q, k) * self.scale + # Turn correlations into probabilites using softmax function + attention = self.attend(q_k_correlations) + # Multiply attention probabilites with the Values + out = torch.mul(attention, v) + # Concatenate results with # of heads and the dimensions of each head + out = rearrange(out, "b h n d -> b n (h d)") + # Apply to Linear layer to give flattened linear output + return self.to_out(out) + + +""" +A simple feed-forward NN, used within components of the ViT. + +Network contains two hidden Linear layers, with an activation function +between them. Layer normalisation (for 1D data) is also applied to input values. + +FeedForward modules are added to the network after multi-head attention modules, +taking the flattened Linear layer outputs from these modules. +One FeedForward module will be placed after the multi-head attention module handling +inputs. A second module will be placed after the chained masked multi-head attention +and multi-head attention modules handling outputs. +As there are N complete network components handling inputs and N complete network +components handling outputs, this means that N FeedForward modules are required +for inputs and N FeedForward modules are required for outputs. +""" +class FeedForward(nn.Module): + + """ + Create the simple linear layers used within the ViT. + + Params: + dimensions (int): the size/dimensions of the input data + hidden_layer_dimensions (int): the size/dimensions of the two hidden + Linear layers + """ + def __init__(self, dimensions, hidden_layer_dimensions): + super().__init__() + # Create the network: + self.network = nn.Sequential( + # Apply 1D layer normalisation (similar to batch norm for 2D data) + nn.LayerNorm(dimensions), + # Add the first hidden Linear layer + nn.Linear(dimensions, hidden_layer_dimensions), + # Apply GELU (Gaussian Error Linear Unit) activation fn + nn.GELU(), + # Add second hidden Linear layer + nn.Linear(hidden_layer_dimensions, dimensions) + ) + + """ + Perform one forward pass (forward propagation) on the Feed-Forward NN. + + Params: + x: 1D representation of input data + Returns: + Computed result of attention module (after second Linear layer is applied) + """ + def forward(self, x): + return self.network(x) + + +""" +Create the whole Transformer (ViT) network, using combinations of the Attention +and FeedForward modules. +""" +class Transformer(nn.Module): + + """ + Create the layers required for the Transformer network. + Add Attention modules, whose outputs are fed into FeedForward modules. + + Params: + dimensions (int): the size/dimensions of the input data + depth (int): the depth of the network (number of required Attention modules, + whose outputs are chained into FeedForward modules) + n_heads (int): the number of heads added to each multi-head attention component + head_dimensions (int): the dimensions/size of each head added to the attention. + mlp_dimensions (int): the size/dimensions of the two hidden Linear layers + in the Feed-Forward components of the Transformer + """ + def __init__(self, dimensions, depth, n_heads, head_dimensions, mlp_dimensions): + super().__init__() + # All operations will be normalised (layer norm for 1D representations, similar to batch norm) + self.layer_norm = nn.LayerNorm(dimensions) + + # Add the # of required chained Attention and FeedForward modules + self.layers = nn.ModuleList([]) + for i in range (depth): + self.layers.append(nn.ModuleList([ + Attention(dimensions=dimensions, n_heads=n_heads, + head_dimensions=head_dimensions), + FeedForward(dimensions=dimensions, hidden_layer_dimensions=mlp_dimensions) + ])) + + + """ + Perform one forward pass (forward propagation) on the Transformer network. + Residual connections are maintained between the input to that sub-component + and the current Attention and FeedForward layers - the residual connection is + added to the output, then normalised. + + Params: + x: 1D representation of input data + Returns: + Computed result of final Attention module (after second Linear layer of + final FeedForward module is applied) + """ + def forward(self, x): + for attention, feed_forward in self.layers: + # Add residual connections between the input to that sub-component and the modules + x = attention(x) + x + x = feed_forward(x) + x + # Normalise the output + return self.layer_norm(x) + + +""" +Creates a positional encoding for the Transformer input data, using a 2D set of +sinusoids. +Every row of the encoding will vary with frequency, allowing for the inputs to +be encoded uniquely and their position located. + +Params: + height (int): the required height of the positional encoding + width (int): the required width of the positional encoding + dimensions (int): the dimensions/size of the input features. This value + must be a multiple of 4. + temperature (int): determines the frequencies used by the sinusoids in the + positional encoding + dtype (torch dtype): the data type for the positional encoding to be stored as + +Returns: + The computed positional encoding +""" +def create_positonal_encodings(height, width, dimensions, temperature=10000, + dtype=torch.float32): + # Set up a 2D set of coordinates in a mesh grid + y, x = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + # Set the frequencies used by the sinusoids in the positional encoding + omega = torch.arange(dimensions // 4) / (dimensions // 4 - 1) + omega = 1.0 / (temperature ** omega) + + # Flatten the x and y coordinates into 1D arrays + y = y.flatten()[:, None] * omega[None, :] + x = x.flatten()[:, None] * omega[None, :] + + # Compute sinusoids and combine them together + positional_encoding = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) + return positional_encoding.type(dtype) + + +""" +Create a simple ViT as mentioned in this paper: https://arxiv.org/abs/2205.01580 +The created model will be used in a classification problem. +Using model S/16 (width=384, depth=12, mlp_head_size=1536, n_heads=6) +""" +class SimpleViT(nn.Module): + """ + Initialise/create a simple ViT model. + Breaks each image up into smaller sized 'patches', which are used as input + tokens. + + Params: + image_size (tuple(int, int)): the size/dimensions of the input image + (height x width) + patch_size (tuple(int, int)): the size/dimensions of the image patches + (height x width). The image height should + be a multiple of the patch height, and + the image width should be a multiple of + the patch width. + n_classes (int): the number of classes in the classification problem + dimensions (int): the size/dimensions of the input data + depth (int): the depth of the network (number of required Attention modules, + whose outputs are chained into FeedForward modules) + n_heads (int): the number of heads added to each multi-head attention component + mlp_dimensions (int): the size/dimensions of the two hidden Linear layers + in the Feed-Forward components of the Transformer + n_channels (int): the number of channels in the input image (3 for RGB) + head_dimensions (int): the dimensions/size of each head added to the attention. + """ + def __init__(self, *, image_size, patch_size, n_classes, dimensions, depth, + n_heads, mlp_dimensions, n_channels=3, head_dimensions=64): + super().__init__() + ''' + The image height should be a multiple of the patch height, and + the image width should be a multiple of the patch width. + ''' + image_height, image_width = image_size + patch_height, patch_width = patch_size + + # Get the dimensions of each patch + patch_dimensions = n_channels * patch_height * patch_width + + ''' + Turn all images into multiple patches ('patchifying'), of the size + (patch height x patch width x num channels). THe patches are 1D tokens. + h - height of image + w - width of image + p1 - patch height + p2 - patch width + ''' + self.to_patch_embedding = nn.Sequential( + Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_height, p2=patch_width), + # Add a layer norm for 1D dat + nn.LayerNorm(patch_dimensions), + # Embed patches in linear layer + nn.Linear(patch_dimensions, dimensions), + # Add layer norm after the linear layer + nn.LayerNorm(dimensions), + ) + + # Create the positional embedding (scale image dimensions by the patch size) + self.positional_embedding = create_positonal_encodings( + height=(image_height // patch_height), + width=(image_width // patch_width), + dimensions=dimensions + ) + + # Add the Transformer network + self.transformer = Transformer(dimensions, depth, n_heads, head_dimensions, + mlp_dimensions) + + # Use average pooling for the network (instead of using a class token) + self.pooling = "mean" + + # Store the identity to perform skip connections + self.to_latent = nn.Identity() + + # Linear layer outputs the model's classifications + self.linear_head = nn.Linear(dimensions, n_classes) + + + """ + Perform a forward pass (forward propagation) of the model. + + Creates a patch embedding of the image (converting it to a 1D token), + then encodes the patch's position. The model is then trained. + + Params: + image: the input image for the model to be trained on + + Returns: + The inear output layer (tcontains the model's classifications in + a one-hot encoding) + """ + def forward(self, image): + # Get the CUDA hardware device + device = image.device + + # Get the patch embedding of the image + x = self.to_patch_embedding(image) + # Get the positonal embedding of the image, send this embedding to the GPU + x += self.positional_embedding.to(device, dtype=x.dtype) + + # Apply the Transformer network to the model + x = self.transformer(x) + # Perform average pooling on the Transformer network + x = x.mean(dim=1) + + # Apply a skip connection to the model + x = self.to_latent(x) + # Output the model's classifications + return self.linear_head(x) + + diff --git a/recognition/TRANSFORMER_43909856/plots/ADNI_sample_data.png b/recognition/TRANSFORMER_43909856/plots/ADNI_sample_data.png new file mode 100644 index 000000000..932803575 Binary files /dev/null and b/recognition/TRANSFORMER_43909856/plots/ADNI_sample_data.png differ diff --git a/recognition/TRANSFORMER_43909856/plots/Dosovitskiy_etal_2021_ViT_diagram.PNG b/recognition/TRANSFORMER_43909856/plots/Dosovitskiy_etal_2021_ViT_diagram.PNG new file mode 100644 index 000000000..3641af9be Binary files /dev/null and b/recognition/TRANSFORMER_43909856/plots/Dosovitskiy_etal_2021_ViT_diagram.PNG differ diff --git a/recognition/TRANSFORMER_43909856/plots/README.md b/recognition/TRANSFORMER_43909856/plots/README.md new file mode 100644 index 000000000..f6f6e39d1 --- /dev/null +++ b/recognition/TRANSFORMER_43909856/plots/README.md @@ -0,0 +1,10 @@ +# plots + +When plots/graphs and input image previews are generated, they are saved to this +directory location. + +I have included some examples of the plots that can be made. + +These plots are useful +for visualising the data, and for viewing model metrics/performance during training or +validation stages. \ No newline at end of file diff --git a/recognition/TRANSFORMER_43909856/plots/ViT_40epochs_loss.png b/recognition/TRANSFORMER_43909856/plots/ViT_40epochs_loss.png new file mode 100644 index 000000000..9746476a4 Binary files /dev/null and b/recognition/TRANSFORMER_43909856/plots/ViT_40epochs_loss.png differ diff --git a/recognition/TRANSFORMER_43909856/plots/ViT_40epochs_test_confusion_matrix.png b/recognition/TRANSFORMER_43909856/plots/ViT_40epochs_test_confusion_matrix.png new file mode 100644 index 000000000..0d7802ab3 Binary files /dev/null and b/recognition/TRANSFORMER_43909856/plots/ViT_40epochs_test_confusion_matrix.png differ diff --git a/recognition/TRANSFORMER_43909856/plots/ViT_40epochs_validation_accuracy.png b/recognition/TRANSFORMER_43909856/plots/ViT_40epochs_validation_accuracy.png new file mode 100644 index 000000000..3a2ac1153 Binary files /dev/null and b/recognition/TRANSFORMER_43909856/plots/ViT_40epochs_validation_accuracy.png differ diff --git a/recognition/TRANSFORMER_43909856/plots/ViT_90epochs_loss.png b/recognition/TRANSFORMER_43909856/plots/ViT_90epochs_loss.png new file mode 100644 index 000000000..4af6ca8dd Binary files /dev/null and b/recognition/TRANSFORMER_43909856/plots/ViT_90epochs_loss.png differ diff --git a/recognition/TRANSFORMER_43909856/plots/ViT_90epochs_test_confusion_matrix.png b/recognition/TRANSFORMER_43909856/plots/ViT_90epochs_test_confusion_matrix.png new file mode 100644 index 000000000..aa2700670 Binary files /dev/null and b/recognition/TRANSFORMER_43909856/plots/ViT_90epochs_test_confusion_matrix.png differ diff --git a/recognition/TRANSFORMER_43909856/plots/ViT_90epochs_validation_accuracy.png b/recognition/TRANSFORMER_43909856/plots/ViT_90epochs_validation_accuracy.png new file mode 100644 index 000000000..52bc878a5 Binary files /dev/null and b/recognition/TRANSFORMER_43909856/plots/ViT_90epochs_validation_accuracy.png differ diff --git a/recognition/TRANSFORMER_43909856/plots/ViT_loss.png b/recognition/TRANSFORMER_43909856/plots/ViT_loss.png new file mode 100644 index 000000000..be5e6720e Binary files /dev/null and b/recognition/TRANSFORMER_43909856/plots/ViT_loss.png differ diff --git a/recognition/TRANSFORMER_43909856/plots/ViT_test_confusion_matrix.png b/recognition/TRANSFORMER_43909856/plots/ViT_test_confusion_matrix.png new file mode 100644 index 000000000..f7c72bf3e Binary files /dev/null and b/recognition/TRANSFORMER_43909856/plots/ViT_test_confusion_matrix.png differ diff --git a/recognition/TRANSFORMER_43909856/plots/ViT_validation_accuracy.png b/recognition/TRANSFORMER_43909856/plots/ViT_validation_accuracy.png new file mode 100644 index 000000000..06b8e5ad0 Binary files /dev/null and b/recognition/TRANSFORMER_43909856/plots/ViT_validation_accuracy.png differ diff --git a/recognition/TRANSFORMER_43909856/predict.py b/recognition/TRANSFORMER_43909856/predict.py new file mode 100644 index 000000000..24065e406 --- /dev/null +++ b/recognition/TRANSFORMER_43909856/predict.py @@ -0,0 +1,191 @@ +import os +import os.path as osp +import torch +import torch.nn as nn +import time +import numpy as np +from torch.utils.data import DataLoader +from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, PrecisionRecallDisplay +import matplotlib.pyplot as plt + +import dataset +import modules + +""" +This file is used to test the ViT model trained on the ADNI dataset. +Any results will be printed out, and visualisations will be provided +where applicable. +""" + +#### Set-up GPU device #### +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +if not torch.cuda.is_available(): + print("Warning: CUDA not found. Using CPU") +else: + print(torch.cuda.get_device_name(0)) + + +#### Model hyperparameters: #### +BATCH_SIZE = 32 +N_CLASSES = 2 +# Dimensions to resize the original 256x240 images to (IMG_SIZE x IMG_SIZE) +IMG_SIZE = 224 + + +#### File paths: #### +# Local dataset path +# DATASET_PATH = osp.join("recognition", "TRANSFORMER_43909856", "dataset", "AD_NC") +# Path to dataset on Rangpur HPC +DATASET_PATH = osp.join("/", "home", "groups", "comp3710", "ADNI", "AD_NC") +OUTPUT_PATH = osp.join("recognition", "TRANSFORMER_43909856", "models") + + +""" +Loads the ADNI dataset's test set. +Loads the previously trained ViT classification model, then tests the model +on the test set. + +Params: + model_filename (str): The file path and file name for the model to be evaluated + save_metrics (bool): If true, saves separate lists of the model's + predicted values and the corresponding observed/empirical + values for each image in the test set (to CSV files). + Otherwise, does not save these values +""" +def test_model(model_filename=osp.join(OUTPUT_PATH, "ViT_ADNI_model.pt"), save_metrics=True): + # Get the testing data (ADNI) + test_data = dataset.load_ADNI_data(dataset_path=DATASET_PATH) + test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False) + + # Initalise a blank slate model + model = modules.SimpleViT(image_size=(IMG_SIZE, IMG_SIZE), patch_size=(16, 16), n_classes=N_CLASSES, + dimensions=384, depth=12, n_heads=6, mlp_dimensions=1536, n_channels=3) + # Move the model to the GPU device + model = model.to(device) + # Load the pre-trained model into the blank slate ViT + model.load_state_dict(torch.load(model_filename, map_location=device)) + + + # Test the model: + print("Testing has started") + # Get a timestamp for when the model testing starts + start_time = time.time() + + # Store the model's predicted classes and the observed/empirical classes + predictions = [] + observed = [] + + model.eval() + with torch.no_grad(): + # Keep track of the total number predictions vs. correct predictions + correct = 0 + total = 0 + for images, labels in test_loader: + images = images.to(device) + labels = labels.to(device) + # Add images to the data and get the predicted classes + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + # Add to the total # of predictions + total += labels.size(0) + # Add correct predictions to a total + correct += (predicted == labels).sum().item() + + # Save the predictions and the observed/empirical class labels + predictions += predicted.cpu() + observed += labels.cpu() + + # Get the amount of time that the model spent testing + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Test accuracy: {round((100 * correct) / total, 5)}%") + print(f"Testing finished. Testing took {round(elapsed_time, 2)} seconds " + +f"({round(elapsed_time/60, 4)} minutes)") + + # Save testing metrics + if save_metrics: + # Create a dir for saving the testing metrics (if one doesn't exist) + if not osp.isdir(OUTPUT_PATH): + os.makedirs(OUTPUT_PATH) + + # Save the model's predictions + np.savetxt(osp.join(OUTPUT_PATH, 'ADNI_test_predictions.csv'), + np.asarray(predictions)) + # Save the observed/empirical values for the test set + np.savetxt(osp.join(OUTPUT_PATH, 'ADNI_test_observed.csv'), + np.asarray(observed)) + + +""" +Load the predicted classes or the empirical/observed classes, for each image in +the test set. These values are loaded from CSV files, which are saved during the +testing process. + +Params: + filename (str): the name of the CSV file to load +Returns: + An array, in which each entry is a classification label (either predicted + or actual), as predicted by the dataset. Labels are either a 0 (AD) or 1 (NC) +""" +def load_test_labels(filename=osp.join(OUTPUT_PATH, 'ADNI_test_predictions.csv')): + # Load the file + labels = np.loadtxt(filename, dtype=float) + # Convert from a numpy array to a python base lib list + return labels.tolist() + + +""" +Plot a confusion matrix, based on the predicted classes and empirical/observed +classes for the test set data. + +Params: + predicted (array[int]): predicted class values for each image in the test set. + Values are either 0 (AD) or 1 (NC) + observed (array[int]): empirical/observed class values for each image in the + test set. Values are either 0 (AD) or 1 (NC) + show_plot (bool): show the plot in a popup window if True; otherwise, don't + show the plot + save_plot (bool): save the plot as a PNG file to the directory "plots" if + True; otherwise, don't save the plot +""" +def plot_confusion_matrix(predicted, observed, show_plot=False, save_plot=False): + cm = confusion_matrix(observed, predicted) + # Create a graph/plot of the confusion matrix + cm_display = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["AD", "NC"]) + cm_display.plot() + # Add a title + plt.title("ViT Transformer (ADNI classifier) test set confusion matrix") + + # Save the plot + if save_plot: + # Create an output folder for the plot, if one doesn't already exist + directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'plots') + if not os.path.exists(directory): + os.makedirs(directory) + # Save the plot in the "plots" directory + plt.savefig(os.path.join(directory, "ViT_test_confusion_matrix.png"), dpi=600) + if show_plot: + # Show the plot + plt.show() + + +""" +Main method - make sure to run any methods in this file within here. +Adding this so that multiprocessing runs appropriately/correctly +on Windows devices. +""" +def main(): + # Test the model + test_model() + + # Load predicted class labels + # predicted = load_test_labels() + # # Load empirical/observed class labels + # observed = load_test_labels(osp.join(OUTPUT_PATH, 'ADNI_test_observed.csv')) + + # # Plot a confusion matrix + # plot_confusion_matrix(predicted, observed, show_plot=True, save_plot=True) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/recognition/TRANSFORMER_43909856/train.py b/recognition/TRANSFORMER_43909856/train.py new file mode 100644 index 000000000..fdd9eba71 --- /dev/null +++ b/recognition/TRANSFORMER_43909856/train.py @@ -0,0 +1,360 @@ +import os +import os.path as osp +import torch +import torch.nn as nn +import time +import numpy as np +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +from torch.utils.data.backward_compatibility import worker_init_fn + +import dataset +import modules + + +""" +This file contains code for training, validating, testing and saving the model. +The ViT model is imported from modules.py, and the data loader +is imported from dataset.py. +The losses and metrics will be plotted during training. +""" + +#### Set-up GPU device #### +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +if not torch.cuda.is_available(): + print("Warning: CUDA not found. Using CPU") +else: + print(torch.cuda.get_device_name(0)) + + +#### Model hyperparameters: #### +N_EPOCHS = 30 +LEARNING_RATE = 0.001 +N_CLASSES = 2 +# Dimensions to resize the original 256x240 images to (IMG_SIZE x IMG_SIZE) +IMG_SIZE = 224 +# The batch size used by the data loaders for the train, validation, and test sets +BATCH_SIZE = 32 + + +#### File paths: #### +# Local dataset path +# DATASET_PATH = osp.join("recognition", "TRANSFORMER_43909856", "dataset", "AD_NC") +# Path to dataset on Rangpur HPC +DATASET_PATH = osp.join("/", "home", "groups", "comp3710", "ADNI", "AD_NC") +OUTPUT_PATH = osp.join("recognition", "TRANSFORMER_43909856", "models") + + +""" +Loads the ADNI dataset into train (and possibly validation) sets. +Initialises the model, then trains the model. + +If a validation set is created, then the model performance will also +be evaluated at the end of every training epoch on the validation set data. +The validation set is effectively used for hyperparameter tuning, where the +hyperparameter being observed is the number of training epochs. + +Params: + save_model_data (bool): if true, saves the model as a .pt file and model + training/validation metrics as .csv files. If false, + doesn't save the model or training metrics +""" +def train_model(save_model_data=True): + # Get the training and validation data (ADNI) and # of total steps + train_images, n_training_points, val_images = \ + dataset.load_ADNI_data_per_patient(dataset_path=DATASET_PATH, train_size=0.8) + # Get the total step (# of batches) + total_step = int(np.ceil(n_training_points / BATCH_SIZE)) + + # Add the training and validation data to data loaders + train_loader = DataLoader(train_images, batch_size=BATCH_SIZE, shuffle=True, + num_workers=0, worker_init_fn=worker_init_fn) + + if val_images is not None: + # If val_images is None, don't create a validation set + val_loader = DataLoader(val_images, batch_size=BATCH_SIZE, shuffle=True, + num_workers=0, worker_init_fn=worker_init_fn) + + # Initalise the model + model = modules.SimpleViT(image_size=(IMG_SIZE, IMG_SIZE), patch_size=(16, 16), n_classes=N_CLASSES, + dimensions=384, depth=12, n_heads=6, mlp_dimensions=1536, n_channels=3) + # Move the model to the GPU device + model = model.to(device) + + # Use binary cross-entropy as the loss function + criterion = nn.CrossEntropyLoss() + + # Use the Adam optimiser for ViT + # optimiser = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=5e-4) + optimiser = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4) + # Use a piecewise linear LR scheduler + scheduler = torch.optim.lr_scheduler.OneCycleLR(optimiser, max_lr=LEARNING_RATE, + steps_per_epoch=total_step, epochs=N_EPOCHS) + # TODO ViT paper uses a different kind of LR scheduler - may want to try this + + # Store the epoch, step, & train loss value for the model at various steps + train_loss_values = [] + # Store the epoch, validation loss, and validation set accuracy at each epoch + val_loss_values = [] + + # Store the model's predicted classes and the observed/empirical classes on the validation set + predictions = [] + observed = [] + + # Train the model: + model.train() + print("Training has started") + # Get a timestamp for when the model training starts + start_time = time.time() + + # Train the model for the given number of epochs + for epoch in range(N_EPOCHS): + + # Train on each image in the training set + for i, (images, labels) in enumerate(train_loader): + images = images.to(device) + labels = torch.Tensor(labels).to(device) + + # Perform a forward pass of the model + outputs = model(images) + # Calculate the training loss + loss = criterion(outputs, labels) + + # Perform backpropagation + optimiser.zero_grad() + loss.backward() + optimiser.step() + + # Print/log the training metrics for every 100 steps, and at the end of each epoch + if (i+1) % 100 == 0 or i+1 == total_step: + print(f"Epoch [{epoch+1}/{N_EPOCHS}] Step [{i+1}/{total_step}] " + + f"Training loss: {round(loss.item(), 5)}") + train_loss_values += [[epoch+1, i+1, total_step, round(loss.item(), 5)]] + + # Evaluate model on validation set (if a validation set exists): + if val_images is not None: + # Keep track of the total number of predictions vs. correct predictions + correct = 0 + total = 0 + + # After training has completed for each epoch, test model performance on validation data + for j, (val_images, val_labels) in enumerate(val_loader): + val_images = val_images.to(device) + val_labels = torch.Tensor(val_labels).to(device) + + # Get predictions on the validation data from the model + val_outputs = model(val_images) + _, predicted = torch.max(val_outputs.data, 1) + + # Save predictions and observed/empirical class labels + predictions += predicted.cpu() + observed += val_labels.cpu() + + # Add to the total # of predictions + total += val_labels.size(0) + # Add correct predictions to a total + correct += (predicted == val_labels).sum().item() + + # Get the validation loss after all predictions have been made + val_loss = criterion(val_outputs, val_labels) + # Print/save metrics for the end of the epoch + print(f"End of epoch [{epoch+1}/{N_EPOCHS}] Validation loss: " + + f"{round(val_loss.item(), 5)} Validation accuracy: " + + f"{round((100 * correct) / total, 5)}%") + val_loss_values += [[epoch+1, round(val_loss.item(), 5), + round((100 * correct) / total, 5)]] + + # Increment the LR scheduler to change the learning rate after each epoch completes + scheduler.step() + + # Get the amount of time that the model spent training + end_time = time.time() + elapsed_time = end_time - start_time + + print(f"Training finished. Training took {round(elapsed_time, 2)} seconds " + + f"({round(elapsed_time/60, 4)} minutes)") + + if save_model_data: + # Create a dir for saving the trained model (if one doesn't exist) + if not osp.isdir(OUTPUT_PATH): + os.makedirs(OUTPUT_PATH) + + # Save the model + torch.save(model.state_dict(), osp.join(OUTPUT_PATH, "ViT_ADNI_model.pt")) + + # Save the training loss values + np.savetxt(osp.join(OUTPUT_PATH, 'ADNI_train_loss.csv'), + np.asarray(train_loss_values)) + + # Save validation metrics (if a validation set was used) + if val_images is not None: + # Save the validation loss values + np.savetxt(osp.join(OUTPUT_PATH, 'ADNI_val_loss.csv'), + np.asarray(val_loss_values)) + + # Save the model's predictions on the validation set + np.savetxt(osp.join(OUTPUT_PATH, 'ADNI_val_predictions.csv'), + np.asarray(predictions)) + + # Save the observed/empirical values for the validation set + np.savetxt(osp.join(OUTPUT_PATH, 'ADNI_val_observed.csv'), + np.asarray(observed)) + + +""" +Plot the change in training loss (binary cross-entropy) over the epochs. +Training loss is reported/updated every 100 training steps, and for the final +step in each training epoch. +If a validation set was used, change in validation loss at the end of each +epoch will also be plotted. + +Params: + train_loss_values (array[[int, int, int, float]]): each entry of the array + contains the current epoch, the current step number, + the total number of steps for this epoch, and the training + set loss recorded at this point. + val_loss_values (array[[int, float, float]]) or None: if this arg is + None, then validation set metrics won't be plotted. + If an array is passed, each entry of the array contains + the current epoch, the validation loss, and the validation + set accuracy recorded at this point. + show_plot (bool): show the plot in a popup window if True; otherwise, don't + show the plot + save_plot (bool): save the plot as a PNG file to the directory "plots" if + True; otherwise, don't save the plot +""" +def plot_loss(train_loss_values, val_loss_values=None, show_plot=False, + save_plot=False): + # Get the train losses + train_loss = [train_loss_values[i][3] for i in range(len(train_loss_values))] + + # Approximate the location of each train loss value within each epoch, using the step counts + current_step = np.array([train_loss_values[i][1] for i in range(len(train_loss_values))]) + total_steps = np.array([train_loss_values[i][2] for i in range(len(train_loss_values))]) + step_position_to_epoch = np.divide(current_step, total_steps) + + # Add these within-epoch estimations to the epoch numbers + epoch = np.array([train_loss_values[i][0] for i in range(len(train_loss_values))]) + epoch_estimation = np.add(step_position_to_epoch, epoch) + + # Set the figure size + plt.figure(figsize=(10,5)) + # Add a title + plt.title("ViT Transformer (ADNI classifier) model loss") + + # Plot the train loss + plt.plot(epoch_estimation, train_loss, label="Training set", color="Blue") + + # Plot the validation loss on the same graph (if required) + if val_loss_values is not None: + # Get the validation losses + val_loss = [val_loss_values[i][1] for i in range(len(val_loss_values))] + # Get the validation epochs + val_epoch = [val_loss_values[i][0] for i in range(len(val_loss_values))] + plt.plot(val_epoch, val_loss, label="Validation set", color="Green") + + # Add axes titles and a legend + plt.xlabel("Number of epochs") + plt.ylabel("Loss (binary cross-entropy)") + plt.legend() + + # Save the plot + # do I look like I know hwat a JPEG is + if save_plot: + # Create an output folder for the plot, if one doesn't already exist + directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'plots') + if not os.path.exists(directory): + os.makedirs(directory) + # Save the plot in the "plots" directory + plt.savefig(os.path.join(directory, "ViT_loss.png"), dpi=600) + + if show_plot: + # Show the plot + plt.show() + + +""" +Loads the training or validation set loss data, which is saved to a CSV file +during the training process. + +Params: + filename (str): the name of the CSV file to load +Returns: + An array of arrays. Each inner array contains the current epoch, the + current step, the total number of steps in the current epoch, and the + training loss at this point. +""" +def load_training_metrics(filename=osp.join(OUTPUT_PATH, 'ADNI_train_loss.csv')): + # Load the file + loss_values = np.loadtxt(filename, dtype=float) + # Convert from a numpy array to a python base lib list + return loss_values.tolist() + + +""" +Plot the change in accuracy of the validation set over the epochs. + +Params: + val_loss_values (array[[int, float, float]]) or None: each entry of the + array contains the current epoch, the validation loss, + and the validation set accuracy recorded at this point. + show_plot (bool): show the plot in a popup window if True; otherwise, don't + show the plot + save_plot (bool): save the plot as a PNG file to the directory "plots" if + True; otherwise, don't save the plot +""" +def plot_val_accuracy(val_loss_values, show_plot=False, + save_plot=False): + # Set the figure size + plt.figure(figsize=(10,5)) + # Add a title + plt.title("ViT Transformer (ADNI classifier) validation set accuracy") + + # Get the validation accuracy + val_loss = [val_loss_values[i][2] for i in range(len(val_loss_values))] + # Get the validation epochs + val_epoch = [val_loss_values[i][0] for i in range(len(val_loss_values))] + plt.plot(val_epoch, val_loss, label="Validation set", color="Orange") + + # Add axes titles and a legend + plt.xlabel("Number of epochs") + plt.ylabel("Accuracy (%)") + plt.legend() + + # Save the plot + if save_plot: + # Create an output folder for the plot, if one doesn't already exist + directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'plots') + if not os.path.exists(directory): + os.makedirs(directory) + # Save the plot in the "plots" directory + plt.savefig(os.path.join(directory, "ViT_validation_accuracy.png"), dpi=600) + + if show_plot: + # Show the plot + plt.show() + + +""" +Main method - make sure to run any methods in this file within here. +Adding this so that multiprocessing runs appropriately/correctly +on Windows devices. +""" +def main(): + # Train the model + train_model() + + #Create training vs validation loss plots + # train_loss_values = load_training_metrics() + # val_loss_values = load_training_metrics(filename=osp.join(OUTPUT_PATH, 'ADNI_val_loss.csv')) + # plot_loss(train_loss_values, val_loss_values, show_plot=True, save_plot=True) + + # # Create validation accuracy plot + # plot_val_accuracy(val_loss_values, show_plot=True, save_plot=True) + +if __name__ == '__main__': + main() + + + +