diff --git a/README.md b/README.md
index 4a064f841..e77aaae6e 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,69 @@
-# Pattern Analysis
-Pattern Analysis of various datasets by COMP3710 students at the University of Queensland.
+# Classifying Alzheimer's Disease Diagnoses Using a Vision Transformer
-We create pattern recognition and image processing library for Tensorflow (TF), PyTorch or JAX.
+This project categorizes the ADNI brain dataset into AD (Alzheimer's Disease) and NC (Normal Cognition) groups. It employs a Vision Transformer (ViT) based on the principles presented in the original ViT paper ("An Image Is Worth 16x16 Words", Dosovitskiy et al., 2020). The model was trained with the Adam optimizer, and the hyperparameters were tuned repeatedly to find a good accuracy. Each sample corresponds to one patient and consists of 20 greyscale slices of 240x256 pixels, and each patient is classified as either AD or NC.
-This library is created and maintained by The University of Queensland [COMP3710](https://my.uq.edu.au/programs-courses/course.html?course_code=comp3710) students.
+
+## Dataset Splitting
-The library includes the following implemented in Tensorflow:
-* fractals
-* recognition problems
+The dataset is already split into 21,500 images for training and 9,000 images for testing. However, a third split was needed for validation, which is created in dataset.py:
+data_val, data_test = random_split(TensorDataset(xtest, ytest), [0.7,0.3])
+random_split divides the provided test set 70/30, yielding 6,300 images for validation and 2,700 for testing.
-In the recognition folder, you will find many recognition problems solved including:
-* OASIS brain segmentation
-* Classification
-etc.
+
+## Preprocessing the data
+The code preprocesses the image data by dividing it into patches, applying layer normalization and multi-head attention, and incorporating positional encoding before feeding it through the Vision Transformer.
+
+## Training the data
+The following parameters were used for training. No parameter for the number of channels was needed, since the data is single-channel greyscale.
+vit = VisionTransformer(input_dimen=128,
+                        hiddenlayer_dimen=256,
+                        number_heads=4,
+                        transform_layers=4,
+                        predict_num=2,
+                        size_patch=(16,16))
+* input_dimen - dimensionality of the input feature vectors to the Transformer
+* hiddenlayer_dimen - dimensionality of the hidden layer in the feed-forward networks within the Transformer
+* number_heads - number of heads to use in the Multi-Head Attention block
+* transform_layers - number of layers to use in the Transformer
+* predict_num - number of classes to predict
+* size_patch - number of pixels that the patches have per dimension
+
+The time taken to finish training depended on the parameters:
+* Adam optimizer, learning rate 1e-4, 75 epochs: accuracy 0.68 (about 5.5 hours)
+* AdamW optimizer, learning rate 3e-4, 100 epochs: accuracy 0.53 (about 7 hours)
+
+## Configuration
+All main configuration is done in train.py. Inside the train function you will find:
+    optimizer = optim.Adam(net.parameters(), lr=2e-4)
+    epochs = 100
+You can change the optimizer, learning rate, and number of epochs here.
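+
+As a minimal sketch of the two runs compared above (values taken from this README, not the committed defaults), swapping optimizers is a one-line change inside train():
+    # Adam, lr=1e-4, 75 epochs gave ~0.68 accuracy in these experiments
+    optimizer = optim.Adam(net.parameters(), lr=1e-4)
+    epochs = 75
+    # AdamW, lr=3e-4, 100 epochs gave ~0.53
+    # optimizer = optim.AdamW(net.parameters(), lr=3e-4)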
+
+The ViT itself is instantiated at the end of train.py:
+
+vit = VisionTransformer(input_dimen=128,
+                        hiddenlayer_dimen=256,
+                        number_heads=4,
+                        transform_layers=4,
+                        predict_num=2,
+                        size_patch=(16,16))
+
+## Results
+Loss vs epoch graph: ![image](https://github.com/HaadiQureshi/VIT-46878467/assets/141606798/64605a94-429c-4dc8-b5fd-8e4e10276942)
+
+Accuracy vs epoch graph: generated by predict.py.
+
+## How to use
+The project consists of four essential files: dataset.py, modules.py, train.py, and predict.py. The primary files to execute are train.py and predict.py. train.py handles training and testing of the model, with the option to save the model, and records the loss and validation accuracy for each epoch. predict.py uses this data to generate graphs of the loss and accuracy curves with the matplotlib library.
+
+Key considerations:
+1. dataset.py loads, preprocesses, and organizes the medical image data from the AD/NC directories: it converts the images to tensors, groups them into training and testing sets with corresponding labels, and creates data loaders for training, testing, and validation.
+2. train.py imports the required libraries, modules, and functions, then loads the data via returnDataLoaders from dataset.py. It defines empty lists for storing losses and accuracies, and sets up a training function built around an optimizer and CrossEntropyLoss.
+3. predict.py plots two separate graphs: accuracy vs epoch, showing the trend of the model's accuracy over the training epochs, and loss vs epoch, showing how the training loss varies throughout training (a sketch of automating this hand-off follows this list).
+4. modules.py contains the functions and classes implementing the Vision Transformer: an image patching function, an attention block class for multi-head attention, and a VisionTransformer class that applies linear transformations and positional embeddings before the Transformer layers.
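+
+The hand-off between train.py and predict.py is currently manual: the recorded loss and accuracy lists are pasted into predict.py by hand. A minimal sketch of automating it (the file names are hypothetical, not part of the current code):
+    # in train.py, after training:
+    import numpy as np
+    np.save("losses.npy", np.array(losses))
+    np.save("accuracies.npy", np.array(accuracies))
+
+    # in predict.py, instead of the hard-coded lists:
+    losses = np.load("losses.npy").tolist()
+    accuracies = np.load("accuracies.npy").tolist()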
+
+# URL
+https://github.com/HaadiQureshi/VIT-46878467.git
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 000000000..95c6a1414
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,89 @@
+import numpy as np
+import torch
+from PIL import Image
+import os
+import torchvision.transforms as transforms
+from torch.utils.data import TensorDataset, DataLoader, random_split
+
+transform = transforms.Compose([
+    transforms.PILToTensor()
+])
+
+xtrain = []
+xtest = []
+slicemax = 20  # 20 slices per patient
+
+# Each loop below walks one class directory in sorted order and stacks every
+# 20 consecutive slices into a single patient volume.
+ntrainimgs_AD = 0
+patient = []
+slice_count = 0
+for filename in sorted(os.listdir('../ADNI_AD_NC_2D/AD_NC/train/AD/')):
+    f = os.path.join('../ADNI_AD_NC_2D/AD_NC/train/AD/', filename)
+    img = Image.open(f)
+    imgtorch = transform(img).float()
+    patient.append(imgtorch/255)  # rescale from [0, 255] to [0, 1]
+    slice_count = (slice_count+1) % slicemax
+    if slice_count == 0:
+        xtrain.append(torch.stack(patient))
+        patient = []
+        ntrainimgs_AD += 1
+
+ntrainimgs_NC = 0
+patient = []
+slice_count = 0
+for filename in sorted(os.listdir('../ADNI_AD_NC_2D/AD_NC/train/NC')):
+    f = os.path.join('../ADNI_AD_NC_2D/AD_NC/train/NC', filename)
+    img = Image.open(f)
+    imgtorch = transform(img).float()
+    patient.append(imgtorch/255)  # rescale from [0, 255] to [0, 1]
+    slice_count = (slice_count+1) % slicemax
+    if slice_count == 0:
+        xtrain.append(torch.stack(patient))
+        patient = []
+        ntrainimgs_NC += 1
+
+ntestimgs_AD = 0
+patient = []
+slice_count = 0
+for filename in sorted(os.listdir('../ADNI_AD_NC_2D/AD_NC/test/AD')):
+    f = os.path.join('../ADNI_AD_NC_2D/AD_NC/test/AD', filename)
+    img = Image.open(f)
+    imgtorch = transform(img).float()
+    patient.append(imgtorch/255)  # rescale from [0, 255] to [0, 1]
+    slice_count = (slice_count+1) % slicemax
+    if slice_count == 0:
+        xtest.append(torch.stack(patient))
+        patient = []
+        ntestimgs_AD += 1
+
+ntestimgs_NC = 0
+patient = []
+slice_count = 0
+for filename in sorted(os.listdir('../ADNI_AD_NC_2D/AD_NC/test/NC')):
+    f = os.path.join('../ADNI_AD_NC_2D/AD_NC/test/NC', filename)
+    img = Image.open(f)
+    imgtorch = transform(img).float()
+    patient.append(imgtorch/255)  # rescale from [0, 255] to [0, 1]
+    slice_count = (slice_count+1) % slicemax
+    if slice_count == 0:
+        xtest.append(torch.stack(patient))
+        patient = []
+        ntestimgs_NC += 1
+
+xtrain = torch.stack(xtrain)
+xtest = torch.stack(xtest)
+# label 1 = AD, label 0 = NC (ordering matches the stacking order above)
+ytrain = torch.from_numpy(np.concatenate((np.ones(ntrainimgs_AD), np.zeros(ntrainimgs_NC)), axis=0)).type(torch.LongTensor)
+ytest = torch.from_numpy(np.concatenate((np.ones(ntestimgs_AD), np.zeros(ntestimgs_NC)), axis=0)).type(torch.LongTensor)
+
+# 70/30 validation/test split of the provided test set
+data_val, data_test = random_split(TensorDataset(xtest, ytest), [0.7, 0.3])
+dataloader_train = DataLoader(TensorDataset(xtrain, ytrain), batch_size=32, shuffle=True)
+dataloader_test = DataLoader(data_test, batch_size=32, shuffle=True)
+dataloader_val = DataLoader(data_val, batch_size=32, shuffle=True)
+
+def returnDataLoaders():
+    return dataloader_train, dataloader_test, dataloader_val
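+
+# The four loops above differ only in the directory and the target list; a sketch
+# of a helper (hypothetical, not used by the rest of this code) that could replace them:
+def load_patients(directory, out_list):
+    """Stack every `slicemax` consecutive slices in `directory` into one patient volume."""
+    n_patients = 0
+    patient_slices = []
+    for filename in sorted(os.listdir(directory)):
+        img = Image.open(os.path.join(directory, filename))
+        patient_slices.append(transform(img).float() / 255)
+        if len(patient_slices) == slicemax:
+            out_list.append(torch.stack(patient_slices))
+            patient_slices = []
+            n_patients += 1
+    return n_patients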
diff --git a/modules.py b/modules.py
new file mode 100644
index 000000000..508926a73
--- /dev/null
+++ b/modules.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+
+def image_patcher(image, size_patch, patch_depth):
+    # input is (Batch, Depth, Channels, Height, Width); split it into
+    # size_patch x size_patch spatial patches and depth chunks of patch_depth slices
+    Batch_Size, Depth, C, Height, Width = image.shape
+    Height_final = Height // size_patch
+    Width_final = Width // size_patch
+    Depth_final = Depth // patch_depth
+    image = image.reshape(Batch_Size, Depth_final, patch_depth, C, Height_final, size_patch, Width_final, size_patch)
+    # permute so the patch-index axes come first and the per-patch axes last
+    image = image.permute(0, 1, 4, 6, 3, 2, 5, 7)
+    # flatten to (Batch, num_patches, C * patch_depth * size_patch * size_patch)
+    image = image.flatten(1, 3).flatten(2, 5)
+    return image
+
+class AttentionBlock(nn.Module):
+    def __init__(self, input_dimen, hiddenlayer_dimen, number_heads):
+        super().__init__()
+        # layer normalization applied to the input of the attention mechanism
+        self.input_layer_norm = nn.LayerNorm(input_dimen)
+        # normalizes the input of the feed-forward sub-block
+        self.output_layer_norm = nn.LayerNorm(input_dimen)
+        # multi-head self-attention (expects sequence-first input)
+        self.multihead_attention = nn.MultiheadAttention(input_dimen, number_heads)
+        self.linear = nn.Sequential(nn.Linear(input_dimen, hiddenlayer_dimen), nn.GELU(),
+                                    nn.Linear(hiddenlayer_dimen, input_dimen),
+                                    )
+
+    def forward(self, image):
+        inp_x = self.input_layer_norm(image)
+        add = self.multihead_attention(inp_x, inp_x, inp_x)[0]
+        image = image + add                                          # residual connection
+        image = image + self.linear(self.output_layer_norm(image))   # residual connection
+        return image
+
+class VisionTransformer(nn.Module):
+    def __init__(
+        self, input_dimen, hiddenlayer_dimen, number_heads, transform_layers, predict_num, size_patch
+    ):
+        super().__init__()
+        (size_patch_x, size_patch_y) = size_patch
+
+        self.size_patch = size_patch_x * size_patch_y
+        # linear projection of flattened patches; the factor 5 is the depth chunk
+        # used in forward() (20 slices grouped into chunks of 5)
+        self.input_layer = nn.Linear(5*self.size_patch, input_dimen)
+        # stack of Transformer encoder blocks
+        self.final_transform = nn.Sequential(*(AttentionBlock(input_dimen, hiddenlayer_dimen, number_heads) for _ in range(transform_layers)))
+
+        self.dense_head = nn.Sequential(nn.LayerNorm(input_dimen), nn.Linear(input_dimen, predict_num))
+        # 240x256 input images are assumed when sizing the positional embedding
+        final_num_patch = 1 + (240 // size_patch_x)*(256 // size_patch_y)
+        self.positional_emb = nn.Parameter(torch.randn(1, 4*final_num_patch, input_dimen))
+        self.classification_tkn = nn.Parameter(torch.randn(1, 1, input_dimen))
+
+    def forward(self, image):
+        # preprocess the input: 16x16 spatial patches in depth chunks of 5 slices
+        # (assumes size_patch=(16,16) and 20 slices per patient)
+        image = image_patcher(image, 16, 5)
+        Batch_Size, x, _ = image.shape
+
+        image = self.input_layer(image)
+
+        # prepend a CLS token and add the positional encoding
+        classification_tkn = self.classification_tkn.repeat(Batch_Size, 1, 1)
+        image = torch.cat([classification_tkn, image], dim=1)
+        image = image + self.positional_emb[:, : x + 1]
+
+        # run the Transformer blocks (sequence-first layout)
+        image = image.transpose(0, 1)
+        image = self.final_transform(image)
+        class_ = image[0]  # classify from the CLS token
+        out = self.dense_head(class_)
+        return out
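+
+# A quick shape sanity check (a sketch under the assumptions noted above; not
+# part of the training pipeline):
+if __name__ == "__main__":
+    vit = VisionTransformer(input_dimen=128, hiddenlayer_dimen=256, number_heads=4,
+                            transform_layers=4, predict_num=2, size_patch=(16, 16))
+    dummy = torch.randn(2, 20, 1, 240, 256)  # (batch, slices, channels, height, width)
+    print(vit(dummy).shape)  # expected: torch.Size([2, 2]) -- one logit per class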
diff --git a/predict.py b/predict.py
new file mode 100644
index 000000000..3596a3e94
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,29 @@
+import matplotlib.pyplot as plt
+
+# Validation accuracies recorded for the first 100 epochs
+accuracies = [0.4962, 0.5274, 0.5917, 0.4944, 0.5124, 0.5476, 0.5922, 0.5431, 0.5013, 0.5294,
+              0.5464, 0.5487, 0.5922, 0.6340, 0.5876, 0.6005, 0.6104, 0.6167, 0.6243, 0.6215,
+              0.6473, 0.6436, 0.5936, 0.6391, 0.6030, 0.6030, 0.5675, 0.6116, 0.6212, 0.5968,
+              0.5843, 0.6056, 0.6073, 0.6175, 0.5985, 0.5948, 0.6186, 0.5752, 0.6056, 0.5769,
+              0.6042, 0.6212, 0.5825, 0.5931, 0.5786, 0.6002, 0.5712, 0.5624, 0.5907, 0.5848,
+              0.6260, 0.6030, 0.5854, 0.5819, 0.6329, 0.6042, 0.6204, 0.6106, 0.6110, 0.6112,
+              0.6126, 0.6126, 0.6135, 0.6098, 0.6081, 0.6087, 0.6130, 0.6167, 0.6118, 0.6130,
+              0.6135, 0.6124, 0.6118, 0.6135, 0.6124, 0.6106, 0.6141, 0.6147, 0.6141, 0.6112,
+              0.6118, 0.6130, 0.6153, 0.6135, 0.6130, 0.6101, 0.6124, 0.6135, 0.6141, 0.6095,
+              0.6147, 0.6112, 0.6118, 0.6093, 0.6159, 0.6098, 0.6087, 0.6093, 0.6093, 0.6093]
+# Training losses recorded for the first 100 epochs
+loss = [0.6979881910716786, 0.6844412877279169, 0.6904780426446129, 0.684895454084172, 0.6812622880234438,
+        0.6792347308467416, 0.6771022852729348, 0.6855413440395804, 0.6720500413109275, 0.6870778799057007,
+        0.6519009260570302, 0.6404731641797459, 0.6227208665188622, 0.5900690169895396, 0.5951909887440064,
+        0.5607741247205174, 0.5638843687141643, 0.5399825283709694, 0.5193865106386297, 0.4769795151317821,
+        0.46142111543346853, 0.45512034174273996, 0.4230612568995532, 0.3895933969932444, 0.39457252358689027,
+        0.3760155246538274, 0.37903575160924124, 0.3169827342909925, 0.3307492321028429, 0.25006428689641114,
+        0.2171676978468895, 0.284952069468358, 0.19351636508808417, 0.19536656217978282, 0.1634677432696609,
+        0.11777753673274727, 0.14037292516406844, 0.23939100105096311, 0.11117816716432571, 0.06792027106070343,
+        0.11127064086715965, 0.08680430388845065, 0.08349954084876705, 0.0602918595815187, 0.04537964773172622,
+        0.018756276154068902, 0.017304050311555758, 0.0667554905932561, 0.056819359415813404, 0.01601847937426475,
+        0.0256724186885335, 0.08536267633248559, 0.016678674338275894, 0.021344836472588426, 0.03200960334951935,
+        0.054271318350562495, 0.032041940687443406, 0.008468315467539737, 0.0025479644400012843, 0.0009727750567596077,
+        0.0007114587334559902, 0.0006080651262035483, 0.0005404480839120772, 0.0004825884388992563, 0.00043984228036339013,
+        0.0004121263824773076, 0.00037888841025586077, 0.0003548304273007328, 0.0003334864805390894, 0.0003150697696529438,
+        0.0002981368049992906, 0.00028354384695001715, 0.00027058320034377496, 0.0002587148782742374, 0.00024809086993199717,
+        0.00023693855730337366, 0.0002282507754417191, 0.00022031878153725034, 0.00021815271654358024, 0.0002045452338814571,
+        0.00019706551778226103, 0.00019002843627651388, 0.0001844407169675619, 0.00017817121774629305, 0.00017225533241905985,
+        0.00016851070519324448, 0.00016327866144231795, 0.0001579179693448275, 0.00015318737795870917, 0.00014883789652444916,
+        0.00014545178911409013, 0.00014506131992675364, 0.00013773206635104383, 0.00013385336698026068, 0.00013058477484532085,
+        0.00012730956672492218, 0.00012449162686072455, 0.00012149684548871043, 0.0001185430414539844, 0.00011599305349484305]
+
+epoch = list(range(100))
+
+plt.figure(figsize=(12, 6))
+plt.plot(epoch, accuracies, marker='o', linestyle='-', color='b', label='Accuracy')
+plt.title('Accuracy vs Epoch')
+plt.xlabel('Epoch')
+plt.ylabel('Accuracy')
+plt.legend()
+plt.grid(True)
+plt.show()
+
+plt.figure(figsize=(12, 6))
+plt.plot(range(len(loss)), loss, marker='o', linestyle='-', color='b', label='Loss')
+plt.title('Loss vs Epoch')
+plt.xlabel('Epoch')
+plt.ylabel('Loss')
+plt.legend()
+plt.grid(True)
+plt.show()
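+
+# A small summary that could be printed alongside the plots (a sketch, computed
+# from the recorded lists above):
+best_epoch = max(range(len(accuracies)), key=accuracies.__getitem__)
+print(f"best validation accuracy {accuracies[best_epoch]:.4f} at epoch {best_epoch}")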
diff --git a/train.py b/train.py
new file mode 100644
index 000000000..eeaa50abb
--- /dev/null
+++ b/train.py
@@ -0,0 +1,48 @@
+import torch
+import torch.nn as nn
+from dataset import returnDataLoaders
+from modules import VisionTransformer
+import torch.optim as optim
+
+dataloader_train, dataloader_test, dataloader_val = returnDataLoaders()
+
+losses = []
+accuracies = []
+
+def train(net, dataloader_train, dataloader_val, cross_entropy):
+    optimizer = optim.Adam(net.parameters(), lr=2e-4)
+    epochs = 100
+    # training loop
+    for epoch in range(epochs):
+        epoch_loss = 0
+        net.train()
+        for (x_batch, y_batch) in dataloader_train:  # for each mini-batch
+            optimizer.zero_grad()
+            loss = cross_entropy(net(x_batch), y_batch)
+            loss.backward()
+            optimizer.step()
+            epoch_loss += loss.detach().item()
+        epoch_loss = epoch_loss / len(dataloader_train)
+        losses.append(epoch_loss)
+
+        # validation accuracy after each epoch
+        net.eval()
+        acc = test(net, dataloader_val)
+        print("epoch:", epoch, "accuracy:", acc, "loss:", epoch_loss, flush=True)
+        accuracies.append(acc)
+
+def test(net, dataloader_val):
+    # fraction of correct predictions, averaged over mini-batches
+    with torch.no_grad():
+        acc = 0
+        for (x_batch, y_batch) in dataloader_val:
+            acc += torch.sum(y_batch == torch.max(net(x_batch).detach(), 1)[1], axis=0)/len(y_batch)
+        acc = acc/len(dataloader_val)
+    return acc
+
+vit = VisionTransformer(input_dimen=128,
+                        hiddenlayer_dimen=256,
+                        number_heads=4,
+                        transform_layers=4,
+                        predict_num=2,
+                        size_patch=(16,16))
+cross_entropy = nn.CrossEntropyLoss()
+train(vit, dataloader_train, dataloader_val, cross_entropy)
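+
+# The README mentions the option to save the model; a minimal sketch (the file
+# name is hypothetical, not part of the current code):
+torch.save(vit.state_dict(), "vit_adni.pt")
+# ...and reload later, e.g. before evaluation:
+# vit.load_state_dict(torch.load("vit_adni.pt"))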