Changes from all commits · 24 commits
091a26a
Initial commit
noam-mendelson Oct 13, 2023
aba7155
Add function to fetch image paths from directory
noam-mendelson Oct 13, 2023
8b8a46a
Add function to extract patient IDs and split data into train/ valida…
noam-mendelson Oct 13, 2023
5787edb
Create Dataset class for loading AD and NC images
noam-mendelson Oct 14, 2023
daf4d41
Data normalisation and augmentation for training
noam-mendelson Oct 15, 2023
8df3880
Organize CViT configuration parameters for clearer instantiation
noam-mendelson Oct 16, 2023
4a9770e
Outlining classes required for model and completion of MultiHeadSelfA…
noam-mendelson Oct 16, 2023
da51c0e
Populate TransformerBlock and ConvolutionalEmbedding modules
noam-mendelson Oct 16, 2023
833537f
Complete populating ConvolutionalEmbedding and ConvolutionalVisionTra…
noam-mendelson Oct 17, 2023
c96ce83
Update data transformations for sets
noam-mendelson Oct 18, 2023
2ca947d
Debugging and configuration parameters updated
noam-mendelson Oct 19, 2023
b8e06c7
Re-instate docstrings and additional comments from earlier versions (…
noam-mendelson Oct 21, 2023
90633b7
Introduced process.py to handle intricacies of loading and processing…
noam-mendelson Oct 22, 2023
2ab6e2f
Formatting and docstring addition
noam-mendelson Oct 22, 2023
d207b8b
The previous commit mistakenly included a file, train.py, containing …
noam-mendelson Oct 22, 2023
37300bb
Introduce train.py script- Implement training and validation phases f…
noam-mendelson Oct 25, 2023
540aa1c
Populate predict.py script- evaluate CViT model accuracy on test data
noam-mendelson Oct 26, 2023
751fc26
Transferred content written in MS Word to README.md, formatted using …
noam-mendelson Oct 26, 2023
8ca0db5
New implementation of model example usage (predict.py)- replaced olde…
noam-mendelson Oct 26, 2023
fde09ac
process.py is made redundant due to restructuring of scripts- no long…
noam-mendelson Oct 26, 2023
a9b7d89
Error in populating train.py- Populate train.py content. Previous com…
noam-mendelson Oct 26, 2023
193f200
Finalise README.md, include additional CViT image and reword
noam-mendelson Oct 26, 2023
ec6ce9c
Remove __pycache__ and update .gitignore
noam-mendelson Oct 26, 2023
f74b5ab
Delete image_characteristics.py
noam-mendelson Oct 26, 2023
Binary file added .DS_Store
Binary file not shown.
Binary file added Images/CViT.png
Binary file added Images/Training_Validation_Accuracies.png
Binary file added Images/Training_Validation_Losses.png
Binary file added Images/vision_transformer.png
113 changes: 102 additions & 11 deletions README.md
@@ -1,15 +1,106 @@
# Pattern Analysis
Pattern Analysis of various datasets by COMP3710 students at the University of Queensland.
## Vision Transformer for the Classification of Alzheimer’s Disease of ADNI Dataset

### Overview/ the Problem

I have implemented a Convolutional Vision Transformer (CViT), an adaptation of the standard Vision Transformer (ViT) approach, to classify images in the ADNI dataset as normal control (NC) or Alzheimer’s disease (AD). The notion of a CViT was inspired by the paper ‘CvT: Introducing Convolutions to Vision Transformers’ ([link](https://arxiv.org/abs/2103.15808)).

### Background/ Model Overview

Prior to deciding on a CViT, I thoroughly researched various aspects of transformers to understand their functionality, advantages, and limitations. This led me to ‘Attention Is All You Need’ ([link](https://arxiv.org/abs/1706.03762)), the seminal deep learning paper that introduced the transformer.

This paper introduced a revolutionary way of processing sequences by relying solely on attention mechanisms, dispensing with the need for recurrent layers. As the title suggests, ‘self-attention’ enables the model to weigh the significance of different parts of an input sequence differently: it captures contextual relations between elements (e.g. pixels), regardless of their position in the sequence. This contextual awareness can lead to more accurate classifications, as the model dynamically adjusts the significance it assigns to various input features based on what it has learned. In the context of AD classification, this means the model focuses on the parts of the input data most indicative of the disease. It pains me to admit that this, however, was not what drew my eye to the paper, but rather the fact that I share a name with one of the authors; it’s a rare occurrence for someone named Noam.
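
To make the self-attention computation concrete, the following is a minimal sketch of the scaled dot-product attention the paper defines (a toy, single-head illustration; not the implementation used in `modules.py`):

```python
import torch
import torch.nn.functional as F

def self_attention(x, w_q, w_k, w_v):
    """Toy single-head self-attention over a sequence x of shape (seq_len, d_model)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v                    # project input into queries, keys, values
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5   # pairwise similarities, scaled by sqrt(d)
    weights = F.softmax(scores, dim=-1)                    # each row sums to 1: how much each element attends to every other
    return weights @ v                                     # output is an attention-weighted sum of values

# Example: a "sequence" of 4 patch embeddings, each 8-dimensional
x = torch.randn(4, 8)
w_q, w_k, w_v = (torch.randn(8, 8) for _ in range(3))
print(self_attention(x, w_q, w_k, w_v).shape)  # torch.Size([4, 8])
```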

#### Vision Transformer (ViT)

A standard ViT breaks an image into fixed-size patches and linearly embeds them as a sequence of vectors. This sequence includes an additional ‘class’ token used for classification tasks. The sequence is then processed by the self-attention mechanisms described above, and the output corresponding to the ‘class’ token passes through a feed-forward neural network (FFN) to predict the image’s class. The ViT leverages positional embeddings to retain information about the image structure. A visualisation of the ViT architecture is shown below.

![Visualisation of ViT](Images/vision_transformer.png)
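
To illustrate the patch-embedding step described above, here is a minimal ViT-style sketch (sizes are illustrative defaults, not those used in this project):

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into fixed-size patches, linearly embed them, and prepend a class token."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        # A conv with kernel = stride = patch_size is equivalent to slicing
        # non-overlapping patches and applying a shared linear projection
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

    def forward(self, x):                               # x: (batch, channels, H, W)
        x = self.proj(x).flatten(2).transpose(1, 2)     # (batch, num_patches, embed_dim)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)                  # prepend the 'class' token
        return x + self.pos_embed                       # positional embeddings keep structure info

print(PatchEmbedding()(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 197, 768])
```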

Whilst this architecture offers numerous advantages in image processing, ViTs also have several limitations in the context of image classification tasks:

1. Data Efficiency- ViTs work best on large, labelled datasets and are outperformed by CNNs on smaller ones
2. Feature Localisation- a ViT treats an image as a sequence of patches, losing the explicit local feature representations that are innate to CNNs
3. Computational Efficiency- the self-attention mechanism in a ViT computes pairwise interactions between patches, which is computationally expensive for high-resolution images
4. Fine-grained Feature Sensitivity- ViTs may overlook subtle cues due to patch-based processing (relevant in the medical imaging context); CNNs capture such details more robustly

Integrating CNNs into the model aims to reduce the impact of these limitations. This is achieved by:

1. Data Efficiency- the CViT model extracts hierarchical features more effectively in lower-data regimes, leveraging the inductive biases of CNNs, which need less data to generalise well
2. Localisation of Features- by beginning with convolutional layers, the CViT retains the advantages of localised feature extraction, enabling the transformer stages to focus on global feature relationships
3. Computational Efficiency- the CNN reduces the spatial resolution (and thus the sequence length) prior to the transformer stage, making attention computations more manageable
4. Fine-grained Feature Sensitivity- the CViT uses the CNN to capture detailed nuances while retaining the global reasoning capabilities of transformers

Merging CNNs and ViTs addresses the shortcomings of ViTs by harnessing the strengths of both architectures: it integrates the hierarchical feature learning ability of CNNs with the high-level reasoning capabilities of transformers. The aim is a model that is more robust, versatile, and proficient than a traditional ViT or CNN architecture in isolation.

#### Convolutional Vision Transformer (CViT)

My implementation of the CViT begins with a series of convolutional layers acting as feature extractors; these segment input images into numerous patches while concurrently learning hierarchical features. This prepares the data for the subsequent transformer architecture, which is designed to capture complex dependencies and relationships between these patches, irrespective of their spatial positions in the image.

The core of the model, the transformer section, is structured into stages, each comprising one or more transformer blocks. These blocks provide the model’s reasoning and analytical capabilities: each contains a multi-head self-attention mechanism followed by an FFN, enabling the model to focus on different facets of the data and consider various contextual relationships.

Unique to this model is the adaptive configuration of attention heads, MLP ratios, and other hyperparameters across the different stages, allowing a more customised approach to learning these hierarchical representations. To guard against overfitting and stabilise learning, the CViT employs regularisation techniques including layer normalisation and dropout.

The final part of the model condenses the transformer’s output into the CLS token and passes it through a linear layer acting as a classifier, translating the information from the preceding stages into concrete predictions for image classification.

A visual of this is shown below.
![Visualisation of CViT](Images/CViT.png)
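
The skeleton below sketches this overall structure (convolutional embedding, staged transformer blocks, CLS-token classifier). It is illustrative only; the actual stage configuration, head counts, and MLP ratios are set via `config_params_dict` in `modules.py`:

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Multi-head self-attention followed by an FFN, with layer norm and dropout."""
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)), nn.GELU(),
            nn.Dropout(dropout), nn.Linear(int(dim * mlp_ratio), dim),
        )

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]        # residual connection around attention
        return x + self.ffn(self.norm2(x))   # residual connection around the FFN

class CViTSketch(nn.Module):
    def __init__(self, in_channels=3, dim=64, num_heads=4, depth=2, num_classes=2):
        super().__init__()
        # Convolutional embedding: learns local features while downsampling,
        # which shortens the sequence the transformer must attend over
        self.conv_embed = nn.Sequential(
            nn.Conv2d(in_channels, dim // 2, kernel_size=7, stride=4, padding=2),
            nn.BatchNorm2d(dim // 2), nn.ReLU(),
            nn.Conv2d(dim // 2, dim, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(dim), nn.ReLU(),
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.blocks = nn.Sequential(*[TransformerBlock(dim, num_heads) for _ in range(depth)])
        self.head = nn.Linear(dim, num_classes)   # linear classifier over the CLS token

    def forward(self, x):
        x = self.conv_embed(x).flatten(2).transpose(1, 2)   # (batch, seq_len, dim)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = self.blocks(torch.cat([cls, x], dim=1))
        return self.head(x[:, 0])                           # classify from the CLS token

print(CViTSketch()(torch.randn(2, 3, 299, 299)).shape)  # torch.Size([2, 2])
```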

## ADNI Brain Dataset & Pre-Processing

The project utilises images from the Alzheimer’s Disease Neuroimaging Initiative (ADNI) dataset ([ADNI](https://adni.loni.usc.edu)). Each 2D, 256 x 240 pixel image is one of 20 slices from a specific patient’s scan collection.

Data pre-processing is conducted within `dataset.py`, making use of various PyTorch functionalities for streamlined and effective data handling.

#### dataset.py
- Importing dataset from designated directories
- Constructing a structured DataFrame that organises image file paths alongside their respective labels, providing foundational mapping for subsequent stages
- Implementing patient-level split by extracting unique patient IDs from image paths; the script ensures a non-overlapping distribution of patients between training and validation sets, preserving the integrity of evaluation
- Conducting data augmentation and normalisation to enhance the robustness of the model
- Facilitating batch processing to expedite the computational process (images batched together during training)
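
A minimal usage sketch of these functions (the directory paths are placeholders for wherever the ADNI data is stored):

```python
from dataset import load_data, create_data_loaders

# Placeholder paths -- point these at the AD and NC training folders
train_AD_dir = "ADNI/train/AD"
train_NC_dir = "ADNI/train/NC"

# Patient-level split: no patient appears in both training and validation
train_AD, train_NC, val_AD, val_NC = load_data(train_AD_dir, train_NC_dir)

# Wrap the splits in DataLoaders, applying the train/val transforms
train_loader, val_loader = create_data_loaders(train_AD, train_NC, val_AD, val_NC, batch_size=16)

images, labels = next(iter(train_loader))  # images: (16, 3, 299, 299); labels: 1 for AD, 0 for NC
```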

## Training and Validation Performance
The model was trained over 25 epochs, with hyperparameters defined in `modules.py`. The loss and accuracy metrics over both the training and validation sets are shown in the figures below.

![Train, val losses](Images/Training_Validation_Losses.png)

![Train, val accuracies](Images/Training_Validation_Accuracies.png)

The model attained an accuracy of 69.4% on the training set and 67.8% on the validation set. The accuracy (and loss) plots show the gap between the training and validation sets decreasing, suggesting that the model generalises rather than memorises the training data. The curves indicate some convergence; training over additional epochs would be expected to enhance this convergence and further guard against over-fitting (beyond the regularisation already built into the hyperparameters), but was not feasible given the additional computational cost.
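
For reference, the training and validation phases follow the standard PyTorch epoch loop; the sketch below is generic (illustrative names, not the exact contents of `train.py`):

```python
import torch

def run_epoch(model, loader, criterion, optimizer=None, device=None):
    """One pass over a loader; trains if an optimizer is given, otherwise evaluates."""
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    training = optimizer is not None
    model.train() if training else model.eval()
    total_loss, correct, count = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss = criterion(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * images.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            count += images.size(0)
    return total_loss / count, correct / count  # mean loss, accuracy for the epoch
```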

## Dependencies
- Python 3.10.12
- PyTorch 2.0.1
- torchvision 0.15.2
- matplotlib 3.7.2
- pandas 2.0.3
- scikit-learn 1.3.0
- Pillow (PIL) 10.0.0

## Testing Environment
GPU access is fundamental for accelerating training and inference. This task made use of Google Colab Pro+ for its faster GPUs and larger memory allocation, both of which this project requires.
- **Platform**: Google Colab Pro+
- **GPU**: NVIDIA GPU
- **OS**: Linux (as provided by Google Colab)

## Usage Description
Ensure all dependencies are installed and that a GPU (or other high-performance machine) is available.

1. **Prepare the data**: load and pre-process the dataset using the `ADNC_Dataset` class within `dataset.py`. Loading and splitting the data via the `load_data` function prevents overlap between patients in the training and validation sets.
2. **Create the data loaders**: once the data is prepared, use the `create_data_loaders` function to create data loaders for the training and validation sets.
3. **Train the model**: run the `train.py` script, which makes use of the data loaders. The number of epochs and batch size can be specified as `--epochs 10` and `--batch_size 32`; otherwise, the defaults are 2 epochs and a batch size of 16 (these were not the specifications used in training for this model).
4. **Predict**: to make predictions using a pre-trained model, run the `predict.py` script, providing the path to the image to be classified (see the sketch below).

Note that the hyperparameters of the model can be adjusted using `config_params_dict` as needed. That is:
- `modules.py` – contains source code for the model, and can be modified if required
- `dataset.py` – can be altered to change the way in which data is pre-processed and handled
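
As an illustration of single-image inference, a hypothetical sketch is shown below; the model class name, constructor call, and checkpoint handling are assumptions based on the repository layout, not the exact contents of `predict.py`:

```python
import torch
from PIL import Image
from torchvision import transforms

from modules import ConvolutionalVisionTransformer, config_params_dict  # assumed names

# Same pre-processing as the validation transform in dataset.py
transform = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvolutionalVisionTransformer(**config_params_dict).to(device)  # assumed constructor
# models.pth is the checkpoint committed in this repo; adjust if it stores the full model
model.load_state_dict(torch.load("models.pth", map_location=device))
model.eval()

image = transform(Image.open("path/to/slice.jpeg").convert("RGB")).unsqueeze(0).to(device)
with torch.no_grad():
    prediction = model(image).argmax(dim=1).item()
print("AD" if prediction == 1 else "NC")
```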

## References
- **ADNI dataset**:
- Alzheimer's Disease Neuroimaging Initiative (ADNI) database. [Link](https://adni.loni.usc.edu)

- **ViT architecture image**:
- https://github.com/google-research/vision_transformer

- **Papers**:
- Vaswani, A. et al. (2017). 'Attention Is All You Need'. [Link](https://arxiv.org/abs/1706.03762)
- Wu, Z. et al. (2021). 'CvT: Introducing Convolutions to Vision Transformers'. [Link](https://arxiv.org/abs/2103.15808)

- **Model & Configuration**:
- Convolutional Vision Transformer (CvT) - used for hyperparameters. [HuggingFace Documentation](https://huggingface.co/docs/transformers/model_doc/cvt#transformers.CvtConfig.embed_dim)
- Note: the linked documentation describes a pre-trained model; only the configuration parameters were referenced here

## License
This project is licensed under the terms of the MIT license.
145 changes: 145 additions & 0 deletions dataset.py
@@ -0,0 +1,145 @@
import torch
from torchvision import transforms
from sklearn.model_selection import train_test_split
import pandas as pd
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image

class ADNC_Dataset(Dataset):
    """
    A custom Dataset class for loading AD and NC images.

    Attributes:
    - AD_image_paths: paths to the AD image files.
    - NC_image_paths: paths to the NC image files.
    - data: DataFrame mapping each image path to its label (1 = AD, 0 = NC).
    - transform: optional transformations to apply to the images.
    """
    def __init__(self, AD_image_paths, NC_image_paths, transform=None):
        self.AD_image_paths = AD_image_paths
        self.NC_image_paths = NC_image_paths

        # Creating a DataFrame mapping image paths to class labels
        AD_df = pd.DataFrame({
            'image_path': self.AD_image_paths,
            'label': [1] * len(self.AD_image_paths)  # 1 for AD
        })

        NC_df = pd.DataFrame({
            'image_path': self.NC_image_paths,
            'label': [0] * len(self.NC_image_paths)  # 0 for NC
        })

        self.data = pd.concat([AD_df, NC_df], axis=0).reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Returns an image and its label (either 0 or 1)."""
        row = self.data.iloc[idx]
        image_path = row['image_path']
        label = torch.tensor(row['label'])

        # Open the image and convert to RGB
        image = Image.open(image_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

def get_image_paths_from_directory(directory_path, valid_extensions=(".jpg", ".jpeg", ".png")):
    """
    Get image paths from a directory, keeping only files with valid extensions.
    """
    if not os.path.exists(directory_path):
        raise ValueError(f"The provided directory {directory_path} does not exist.")

    all_images = []
    for image_file in os.listdir(directory_path):
        if any(image_file.endswith(ext) for ext in valid_extensions):
            image_path = os.path.join(directory_path, image_file)
            all_images.append(image_path)
    return all_images

def extract_patient_id(image_path):
    """
    Extract the patient ID from an image path.
    Assumes filenames of the form '<patientID>_<slice>.<ext>'.
    """
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    return base_name.split('_')[0]

def load_data(train_images_paths_AD, train_images_paths_NC):
    """
    Load and split the image dataset into training and validation sets,
    ensuring no patient overlap between the sets.
    """
    # Get image paths for the AD and NC training directories
    all_train_images_paths_NC = get_image_paths_from_directory(train_images_paths_NC)
    all_train_images_paths_AD = get_image_paths_from_directory(train_images_paths_AD)

    # Extract unique patient IDs for each class
    all_patient_ids_AD = list(set(extract_patient_id(path) for path in all_train_images_paths_AD))
    all_patient_ids_NC = list(set(extract_patient_id(path) for path in all_train_images_paths_NC))

    # Split patient IDs into training and validation sets (80%/20% split)
    train_patient_ids_AD, val_patient_ids_AD = train_test_split(all_patient_ids_AD, test_size=0.20, random_state=42)
    train_patient_ids_NC, val_patient_ids_NC = train_test_split(all_patient_ids_NC, test_size=0.20, random_state=42)

    # Map patient IDs back to image paths for the training and validation sets
    train_images_AD = [path for path in all_train_images_paths_AD if extract_patient_id(path) in train_patient_ids_AD]
    val_images_AD = [path for path in all_train_images_paths_AD if extract_patient_id(path) in val_patient_ids_AD]
    train_images_NC = [path for path in all_train_images_paths_NC if extract_patient_id(path) in train_patient_ids_NC]
    val_images_NC = [path for path in all_train_images_paths_NC if extract_patient_id(path) in val_patient_ids_NC]

    return train_images_AD, train_images_NC, val_images_AD, val_images_NC

def create_data_loaders(train_images_AD, train_images_NC, val_images_AD, val_images_NC, batch_size):
    """
    Create data loaders for the training and validation sets with the specified transformations.
    """
    # Define the data transformations for the training and validation sets
    # (images are upscaled from 256 x 240 to 299 x 299; the training set is
    # additionally augmented with random horizontal flips)
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(299),
            transforms.CenterCrop(299),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(299),
            transforms.CenterCrop(299),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    train_dataset = ADNC_Dataset(train_images_AD, train_images_NC, transform=data_transforms['train'])
    val_dataset = ADNC_Dataset(val_images_AD, val_images_NC, transform=data_transforms['val'])

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_dataloader, val_dataloader

def load_test_data(test_images_paths_AD, test_images_paths_NC):
    """
    Load test image paths from the specified directories.
    All test patients are retained; no further splitting is performed.
    """
    all_test_images_paths_NC = get_image_paths_from_directory(test_images_paths_NC)
    all_test_images_paths_AD = get_image_paths_from_directory(test_images_paths_AD)

    # Collect the unique patient IDs present in the test set
    all_patient_ids_AD_test = list(set(extract_patient_id(path) for path in all_test_images_paths_AD))
    all_patient_ids_NC_test = list(set(extract_patient_id(path) for path in all_test_images_paths_NC))

    # Map patient IDs back to image paths for the test set
    test_images_AD = [path for path in all_test_images_paths_AD if extract_patient_id(path) in all_patient_ids_AD_test]
    test_images_NC = [path for path in all_test_images_paths_NC if extract_patient_id(path) in all_patient_ids_NC_test]

    return test_images_AD, test_images_NC
Binary file added models.pth
Binary file not shown.