diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 000000000..d47954f6c Binary files /dev/null and b/.DS_Store differ diff --git a/Images/CViT.png b/Images/CViT.png new file mode 100644 index 000000000..ea5a14d98 Binary files /dev/null and b/Images/CViT.png differ diff --git a/Images/Training_Validation_Accuracies.png b/Images/Training_Validation_Accuracies.png new file mode 100644 index 000000000..1dea2b279 Binary files /dev/null and b/Images/Training_Validation_Accuracies.png differ diff --git a/Images/Training_Validation_Losses.png b/Images/Training_Validation_Losses.png new file mode 100644 index 000000000..a2f83b4b6 Binary files /dev/null and b/Images/Training_Validation_Losses.png differ diff --git a/Images/vision_transformer.png b/Images/vision_transformer.png new file mode 100644 index 000000000..94930cb88 Binary files /dev/null and b/Images/vision_transformer.png differ diff --git a/README.md b/README.md index 4a064f841..c97463213 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,106 @@ -# Pattern Analysis -Pattern Analysis of various datasets by COMP3710 students at the University of Queensland. +## Vision Transformer for the Classification of Alzheimer’s Disease of ADNI Dataset -We create pattern recognition and image processing library for Tensorflow (TF), PyTorch or JAX. +### Overview/ the Problem -This library is created and maintained by The University of Queensland [COMP3710](https://my.uq.edu.au/programs-courses/course.html?course_code=comp3710) students. +I have implemented a Convolutional Neural Network Vision Transformer (CViT), an adaptation of the standard Vision Transformer (ViT) approach to classify Alzheimer’s disease (normal NC and Alzheimer’s Disease AD) within the ADNI dataset. The notion of a CViT was encouraged by the CvT: Introducing Convolutions to Vision Transformers paper ([link](https://arxiv.org/abs/2103.15808)). -The library includes the following implemented in Tensorflow: -* fractals -* recognition problems +### Background/ Model Overview -In the recognition folder, you will find many recognition problems solved including: -* OASIS brain segmentation -* Classification -etc. +Prior to deciding on a CNN-VIT, I thoroughly researched various aspects of transformers to understand their functionality, advantages, and limitations. This led me to the paper ‘Attention Is All You Need’ ([link](https://arxiv.org/abs/1706.03762)), a prevalent paper in the field of deep learning introducing the concept of a transformer. + +This paper introduced a revolutionised way of processing sequences, by solely relying on attention mechanisms, dispensing the need for recurrent layers. As suggested by the title, the utilisation of “self-attention” enables the model to weigh the significance of different parts of an input sequence differently. It is able to capture contextual relations between elements (i.e. pixels), regardless of their position in the sequence. This contextual awareness can lead to more accurate classifications, as the model dynamically adjusts the significance it assigns to various input features based on the information it has learned. In the context of AD classification, the model focuses on crucial parts of the input data that are most indicative of the disease. It pains me to admit that this however, was not what drew my eye to this paper, but rather the fact that I shared a name with one of the authors- it’s a rare occurrence for someone named Noam. + +#### Vision Transformer (ViT) + +A standard ViT breaks down images into fixed-size patches and linearly embeds them as sequences of vectors. This sequence includes an additional ‘class’ token for classification tasks. These sequences are then processed, applying self-attention mechanisms as mentioned above. The output corresponding to the ‘class’ token passes through a feed-forward neural network (FFN) to predict the image’s class. ViT leverages positional embeddings to maintain the image structure information. A visualisation of the ViT articheture can be seen below. + +![Visualisation of ViT](Images/vision_transformer.png) + +Whilst this model ushers numerous advantages in image processing, ViTs also yield several limitations in the context of image classification tasks: + +1. Data Efficiency- ViT work best on large, labelled datasets- outperformed by CNNs on smaller datasets +2. Feature Localisation- ViT treats an image as a sequence of patches, losing explicit local feature representations that are innate to CNNs +3. Computational Efficiency- self-attention mechanism in ViT computes pairwise interactions between patches (computationally expensive for high-resolution images) +4. Fine-grained Feature Sensitivity- ViTs may overlook subtle cues due to patch-based processing (relevant in medical image context)- CNNs capture such details more robustly + +Integrating CNNs into the model aims to reduce the impact of these limitations. This is achieved by: + +1. Data Efficiency- CViT model can extract hierarchical features better in lower data regimes- leverages inductive biases of CNNs that require less data to generalise well +2. Localisation of Features- Beginning with convolutional layers, CViT maintains advantages of localised feature extraction- enables transformer part to focus on global feature relationships +3. Computational Efficiency- CNN used to reduce spatial resolution (and thus sequence length) prior to transformer stage, making attention computations more manageable +4. Fine-grained Feature Sensitivity- CViT utilise CNN in capturing detailed nuances and global reasoning capabilities of transformers + +Merging CNNs and ViTs addresses shortcomings of ViTs by harnessing the strengths of both architectures; it integrates the hierarchal feature learning ability of CNNs with the high-level reasoning capabilities of Transformers. This aims to facilitate the robustness, versatility, and proficiency of the transformer, and hence, outperform traditional ViT or CNN architectures in isolation. + +#### Convolutional Vision Transformer (CViT) + +My implementation of CViT initiates with a series of convolutional layers, acting as feature extractors; segmenting input images into numerous patches and concurrently learning hierarchical features. This prepares the data for the subsequent transformer architecture, designed to capture complex dependencies and relationships between these patches, irrespective of their spatial positions in the image. + +The core of this model, the transformer section, is structured into stages; each of which comprising of one or more transformer blocks. These blocks are integral in handling the model’s reasoning and analytical capabilities, and each block contains multi-head self-attention mechanisms followed by FFNs. This enables the model to focus on different facets of the data and consider various contextual relationships. + +Unique to this model is the adaptive configuration of attention heads, MLP ratios, and other hyperparameters across different stages, allowing a more customised approach to learn these hierarchical representations. To ensure the model’s resilience against overfitting and facilitate more stabilised learning, the CViT employs specific regularisation techniques, including layer normalisation and dropout strategies. + +The final part of the model compresses the transformer’s output, focusing on to a CLS token and passes it through a linear layer that acts as a classifier. This translates the information from preceding stages into concrete predictions for image classifications. + +A visual of this is shown below. +![Visualisation of CViT](Images/CViT.png) + +## ADNI Brain Dataset & Pre-Processing + +The project utilises images from the Alzheimer’s Disease Neuroimaging Initiative (ADNI) dataset ([ADNI](https://adni.loni.usc.edu)). Each 2D, 256 x 240 pixel image is one of 20 slices from a specific patient’s scan collection. + +Data pre-processing is conducted within `dataset.py`, making use of various Pytorch functionalities for streamlined and effective data handling. + +#### Dataset.py +- Importing dataset from designated directories +- Constructing a structured DataFrame that organises image file paths alongside their respective labels, providing foundational mapping for subsequent stages +- Implementing patient-level split by extracting unique patient IDs from image paths; the script ensures a non-overlapping distribution of patients between training and validation sets, preserving the integrity of evaluation +- Conducting data augmentation and normalisation to enhance the robustness of the model +- Facilitating batch processing to expedite the computational process (images batched together during training) + +## Training and Validation Performance +The model was trained over 25 epochs, with hyperparameters defined in modules.py. The loss and accuracy metrics over both training and validation sets are shown in the below figures. + +![Train, val losses](Images/Training_Validation_Losses.png) + +![Train, val accuracies](Images/Training_Validation_Accuracies.png) + +The model attained an accuracy of 69.4% for the training set and 67.8% for the validation set. The accuracy (and loss) plots indicate that the discrepancy between training and validation set decreases, suggesting that the model generalises well as opposed to memorising the training data. Training the model over additional epochs is likely to further prevent over-fitting, despite attempted prevention in hyperparameters. However, this was not feasible given the additional computational cost associated in doing so. The graphs indicate some convergence, and training the model on additional epochs is expected to enhance this convergence. + +## Dependencies +- Python 3.10.12 +- PyTorch 2.0.1 +- torchvision 0.15.2 +- matplotlib 3.7.2 +- pandas 2.0.3 +- scikit-learn 1.3.0 +- Pillow (PIL) 10.0.0 + +## Testing Environment +GPU access is fundamental for accelerating training and inference processing. This task made use of Google Colab Pro+'s GPU to benefit from its faster GPU and access to more memory, which are much needed for this project. +- **Platform**: Google Colab Pro+ +- **GPU**: NVIDIA GPU +- **OS**: Linux (as provided by Google Colab) + +## Usage Description +Ensure all dependencies are installed, and access to a GPU or other high-performing machine. To prepare the dataset, data loading and pre-processing is required by making use of the ADNC_Dataset class within dataset.py. Next, loading and splitting of the data via the load_data function prevents overlap between patients in the training and validation sets. Once the data is prepared, the create_data_loaders function is used to create data loaders for the training and validation sets. Once this data handling is complete, the model can be trained, using the train.py script; making use of the data loaders. The number of epochs and batch size can be specified as such --epochs 10 and --batch_size 32. Otherwise, the default is set to 2 epochs and a batch size of 16 (these were not the specifications used in training for this model). To make predictions using a pre-trained model, the predict.py script can be used, by providing the path to the image wanting to be classified. Note that the hyperparameters of the model can be adjusted using the config_params_dict as needed. That is: + - ‘modules.py’ – contains source code for the model, and can be modified if required + - 'dataset.py’ – can be altered to change the way in which data is pre-processed and handled. + +## References +- **ADNI dataset**: + - Alzheimer's Disease Neuroimaging Initiative (ADNI) database. [Link](https://adni.loni.usc.edu) + +- **ViT architecture image**: + - https://github.com/google-research/vision_transformer + +- **Papers**: + - Vaswani, A. et al. (2017). 'Attention Is All You Need'. [Link](https://arxiv.org/abs/1706.03762) + - Wu, Z. et al. (2021). 'CvT: Introducing Convolutions to Vision Transformers'. [Link](https://arxiv.org/abs/2103.15808) + +- **Model & Configuration**: + - Convolutional Vision Transformer (CvT) - used for hyperparameters. [HuggingFace Documentation](https://huggingface.co/docs/transformers/model_doc/cvt#transformers.CvtConfig.embed_dim) + - Note: above shows use of pre-trained model + +## License +This project is licensed under the terms of the MIT license. \ No newline at end of file diff --git a/dataset.py b/dataset.py new file mode 100644 index 000000000..59fc4de82 --- /dev/null +++ b/dataset.py @@ -0,0 +1,145 @@ +import torch +from torchvision import transforms +from sklearn.model_selection import train_test_split +import pandas as pd +import os +import random +from torch.utils.data import Dataset, DataLoader +from PIL import Image + +class ADNC_Dataset(Dataset): + """ + A custom Dataset class for loading AD and NC images. + + Attributes: + - image_paths: A list of paths to the image files. + - transforms: Optional transformations to apply to the images. + """ + def __init__(self, AD_image_paths, NC_image_paths, transform=None): + self.AD_image_paths = AD_image_paths + self.NC_image_paths = NC_image_paths + + # Creating a DataFrame + AD_df = pd.DataFrame({ + 'image_path': self.AD_image_paths, + 'label': [1]*len(self.AD_image_paths) # 1 for AD + }) + + NC_df = pd.DataFrame({ + 'image_path': self.NC_image_paths, + 'label': [0]*len(self.NC_image_paths) # 0 for NC + }) + + self.data = pd.concat([AD_df, NC_df], axis=0).reset_index(drop=True) + #test code + # pd.set_option('display.max_colwidth', None) + # print(self.data.head()) + self.transform = transform + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + """Returns an image and its label (either 0 or 1).""" + row = self.data.iloc[idx] + image_path = row['image_path'] + label = torch.tensor(row['label']) + + # Open the image and convert to RGB + image = Image.open(image_path).convert("RGB") + + if self.transform: + image = self.transform(image) + + return image, label + +def get_image_paths_from_directory(directory_path, valid_extensions=[".jpg", ".jpeg", ".png"]): + """ + Get image paths from directory with valid extensions + """ + if not os.path.exists(directory_path): + raise ValueError(f"The provided directory {directory_path} does not exist.") + + all_images = [] + for image_file in os.listdir(directory_path): + if any(image_file.endswith(ext) for ext in valid_extensions): + image_path = os.path.join(directory_path, image_file) + all_images.append(image_path) + return all_images + +def extract_patient_id(image_path): + """ + Extract the patient ID from image path. + """ + + base_name = os.path.splitext(os.path.basename(image_path))[0] + return base_name.split('_')[0] + +def load_data(train_images_paths_AD, train_images_paths_NC): + """ + Load and split image dataset into training and validation sets whilst ensuring no patient overlap between sets + """ + # Get image paths for training and test datasets + + all_train_images_paths_NC = get_image_paths_from_directory(train_images_paths_NC) + all_train_images_paths_AD = get_image_paths_from_directory(train_images_paths_AD) + + # Extract unique patient IDs for training and test sets + all_patient_ids_AD = list(set(extract_patient_id(path) for path in all_train_images_paths_AD)) + all_patient_ids_NC = list(set(extract_patient_id(path) for path in all_train_images_paths_NC)) + # Split patient IDs into training and validation sets (e.g., 80%, 20% split) + train_patient_ids_AD, val_patient_ids_AD = train_test_split(all_patient_ids_AD, test_size=0.20, random_state=42) + train_patient_ids_NC, val_patient_ids_NC = train_test_split(all_patient_ids_NC, test_size=0.20, random_state=42) + # Map patient IDs back to image paths for training and validation sets + train_images_AD = [path for path in all_train_images_paths_AD if extract_patient_id(path) in train_patient_ids_AD] + val_images_AD = [path for path in all_train_images_paths_AD if extract_patient_id(path) in val_patient_ids_AD] + train_images_NC = [path for path in all_train_images_paths_NC if extract_patient_id(path) in train_patient_ids_NC] + val_images_NC = [path for path in all_train_images_paths_NC if extract_patient_id(path) in val_patient_ids_NC] + + + return train_images_AD, train_images_NC, val_images_AD, val_images_NC + +def create_data_loaders(train_images_AD, train_images_NC, val_images_AD, val_images_NC, batch_size): + """ + Create data loaders for training and validation sets with specified transformations. + """ + # Define the data transformation for train, validation, and test + data_transforms = { + 'train': transforms.Compose([ + transforms.Resize(299), + transforms.CenterCrop(299), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]), + 'val': transforms.Compose([ + transforms.Resize(299), + transforms.CenterCrop(299), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]), + } + + train_dataset = ADNC_Dataset(train_images_AD, train_images_NC, transform=data_transforms['train']) + val_dataset = ADNC_Dataset(val_images_AD, val_images_NC, transform=data_transforms['val']) + + train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4) + val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4) + + return train_dataloader, val_dataloader + +def load_test_data(test_images_paths_AD, test_images_paths_NC): + """ + Loads test data from specified directory and filters patient ID + """ + all_test_images_paths_NC = get_image_paths_from_directory(test_images_paths_NC) + all_test_images_paths_AD = get_image_paths_from_directory(test_images_paths_AD) + + all_patient_ids_AD_test = list(set(extract_patient_id(path) for path in all_test_images_paths_AD)) + all_patient_ids_NC_test = list(set(extract_patient_id(path) for path in all_test_images_paths_NC)) + + # Map patient IDs back to image paths for test set + test_images_AD = [path for path in all_test_images_paths_AD if extract_patient_id(path) in all_patient_ids_AD_test] + test_images_NC = [path for path in all_test_images_paths_NC if extract_patient_id(path) in all_patient_ids_NC_test] + + return test_images_AD, test_images_NC diff --git a/models.pth b/models.pth new file mode 100644 index 000000000..40bb4ab90 Binary files /dev/null and b/models.pth differ diff --git a/modules.py b/modules.py new file mode 100644 index 000000000..664f43234 --- /dev/null +++ b/modules.py @@ -0,0 +1,282 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +def configuration(): + # Initialize the model, loss function, and optimizer + config_params_dict = { + "general": { + "num_channels": 3, # RGB + "num_classes": 2 + }, + "num_classes": 2, + "patches": { + "sizes": [7, 3, 3], # kernel size of each encoder’s patch embedding. + "strides": [4, 2, 2], # stride size ^^ + "padding": [2, 1, 1] + }, + "transformer": { + "embed_dim": [64, 192, 384], + "hidden_size": 384, #no. of features in the hidden state + "num_heads": [2, 4, 6], # Matching the number of blocks + "depth": [1, 1, 1], # Adjust this according to the number of blocks + "mlp_ratios": [4.0, 4.0, 4.0, 4.0], # size of the hidden layer: size of the input layer + "attention_drop_rate": [0.0, 0.0, 0.0], + "drop_rate": [0.0, 0.0, 0.0], + "drop_path_rate": [0.0, 0.0, 0.1], + "qkv": { # queries (q), keys (k), and values (v) + "bias": [True, True, True], + "projection_method": ["dw_bn", "dw_bn", "dw_bn"], + "kernel": [3, 3, 3], + "padding": { + "kv": [1, 1, 1], + "q": [1, 1, 1] + }, + "stride": { + "kv": [2, 2, 2], + "q": [1, 1, 1] + } + }, + "cls_token": [False, False, True] + }, + "initialisation": { + "range": 0.02, + "layer_norm_eps": 1e-6 + } + } + return config_params_dict +class CViTConfig: + """ + Configuration class for Convolutional Vision Transformer (CViT) containing the configuration of the + CvT- used to instantiate the model with specific architecture parameters. + """ + def __init__(self, config_params): + for key, value in config_params.items(): + setattr(self, key, value) + +class MultiHeadSelfAttention(nn.Module): + """ + Implements the multi-head self-attention mechanism. + The attention mechanism uses scaled dot-product attention- operates on qkv projection of the input. + """ + def __init__(self, config): + super().__init__() + num_heads = config.transformer['num_heads'][0] + hidden_size = config.transformer['hidden_size'] + self.head_dim = hidden_size // num_heads # calculate dimension of each head + + if hidden_size % num_heads != 0: + raise ValueError(f"hidden_size ({hidden_size}) must be divisible by the number of heads ({num_heads})") + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.scaling = self.head_dim ** -0.5 # scaling factor for the dot product attention + + # Setting up the query, key, and value linear projection layers + self.query_projection = nn.Linear(hidden_size, hidden_size) + self.key_projection = nn.Linear(hidden_size, hidden_size) + self.value_projection = nn.Linear(hidden_size, hidden_size) + + # Define the `all_head_size` attribute + self.all_head_size = self.head_dim * self.num_heads + + # Output projection layer - takes concatenated output of all attention heads and projects back to the model dimension + self.output_projection = nn.Linear(hidden_size, hidden_size) + + self.dropout = nn.Dropout(config.transformer['attention_drop_rate'][0]) # prevent overfitting + self.layer_norm = nn.LayerNorm(hidden_size, eps=config.initialisation['layer_norm_eps']) # stability + + def transpose_for_scores(self, x): + """ + Reshapes the 'x' tensor to separate the different attention heads- preparing it for the attention calculation. + """ + new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dim) # size adjusted based on the number of attention heads and the size of each head. + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) # permute to get shape [batch_size, num_heads, seq_length, head_dim] + + def forward(self, hidden_states): + # Linear operations on input + mixed_query_layer = self.query_projection(hidden_states) + mixed_key_layer = self.key_projection(hidden_states) + mixed_value_layer = self.value_projection(hidden_states) + + # Transpose for multi-head attention and apply attention mechanism. + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Attention score calculation. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores.mul_(self.scaling) + + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # Dropout - helps prevent overfitting + attention_probs = self.dropout(attention_probs) + + # Context layer - weighted sum of the value layer based on attention probabilities. + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) # Reshape + + output = self.output_projection(context_layer) # Projecting back to the original dimension + return output + + +class TransformerBlock(nn.Module): + """ + Transformer Block module comprising of multi-head self-attention mechanism and position-wise feed-forward network (FFN) + """ + def __init__(self, config, index): + super(TransformerBlock, self).__init__() + + # Extracting the configuration parameters based on the block's index + hidden_size = config.transformer['hidden_size'] + num_heads = config.transformer['num_heads'][index] # Corrected to use 'index' + dropout_rate = config.transformer['drop_rate'][index] + mlp_ratio = config.transformer['mlp_ratios'][index] + attention_dropout_rate = config.transformer['attention_drop_rate'][index] + + # Ensure the division is integer + self.attention_head_size = int(hidden_size // num_heads) + self.all_head_size = num_heads * self.attention_head_size + + # Layer for MultiHeadSelfAttention + self.self_attention = MultiHeadSelfAttention(config) # Pass the 'config' object here + self.attention_output_dropout = nn.Dropout(attention_dropout_rate) + self.attention_output_layer_norm = nn.LayerNorm(hidden_size, eps=config.initialisation['layer_norm_eps']) + + # Parameters for the feed-forward network (FFN) + self.ffn_output_layer_norm = nn.LayerNorm(hidden_size, eps=config.initialisation['layer_norm_eps']) + ffn_hidden_size = int(hidden_size * mlp_ratio) # size of the hidden layer in FFN + self.ffn = nn.Sequential( + nn.Linear(hidden_size, ffn_hidden_size), + nn.GELU(), # GELU activation function + nn.Dropout(dropout_rate), # Regularization with dropout + nn.Linear(ffn_hidden_size, hidden_size), + nn.Dropout(dropout_rate), # Regularization with dropout + ) + + def forward(self, hidden_states): + # Self-attention part + attention_output = self.self_attention(hidden_states) + attention_output = self.attention_output_dropout(attention_output) + + # Adding the residual connection, followed by normalization + attention_output = self.attention_output_layer_norm(attention_output + hidden_states) + + # Feed-forward network (FFN) part + ffn_output = self.ffn(attention_output) + # Adding the residual connection, followed by normalization + ffn_output = self.ffn_output_layer_norm(ffn_output + attention_output) + + return ffn_output + + +class ConvolutionalEmbedding(nn.Module): + """ + Embed images via convolutional layers in CViT- replaces the typical token embedding in a standard transformer model. + """ + def __init__(self, config): + super(ConvolutionalEmbedding, self).__init__() + + # Convolutional layers configuration + self.conv_layers = nn.ModuleList() # List storing a sequence of convolutions + self.conv_norms = nn.ModuleList() + + # Extract configuration parameters + patch_sizes = config.patches['sizes'] + patch_strides = config.patches['strides'] + patch_padding = config.patches['padding'] + embed_dims = config.transformer['embed_dim'] + + # Calculate the number of convolutional layers based on the configuration list length + num_conv_layers = len(patch_sizes) + + for i in range(num_conv_layers): + # Extract individual configuration parameters for the current layer + kernel_size = patch_sizes[i] + stride = patch_strides[i] + padding = patch_padding[i] + out_channels = embed_dims[i] + + # Create layer norm for this layer + layer_norm = nn.LayerNorm(out_channels) + + # Determine the number of input channels for the current layer + in_channels = config.general['num_channels'] if i == 0 else embed_dims[i - 1] + + # Create the convolutional layer with the current configuration + conv_layer = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding + ) + + # Add the created layer and its corresponding layer norm to the module lists + self.conv_layers.append(conv_layer) + self.conv_norms.append(layer_norm) + + def forward(self, x): + # Pass input through convolutional layers + for conv_layer, layer_norm in zip(self.conv_layers, self.conv_norms): + x = conv_layer(x) + x = F.gelu(x) # Apply GELU activation + # x = layer_norm(x) # Apply layer normalization + + # Reshape tensor for compatibility with subsequent transformer layers + batch_size, embed_dim, height, width = x.size() + x = x.view(batch_size, embed_dim, -1).transpose(1, 2) # Flatten spatial dimensions and move embedding dimension + + return x + + +class ConvolutionalVisionTransformer(nn.Module): + """ + CViT integrates CNNs with transformers for image processing + """ + def __init__(self, config): + super(ConvolutionalVisionTransformer, self).__init__() + + #Initialise convolutional embedding + self.conv_embedding=ConvolutionalEmbedding(config) + + #Transformer blocks- considering different stages with various depths + self.transformer_stages = nn.ModuleList() + block_index = 0 #unified index for all blocks across stages + for stage_depth in config.transformer['depth']: + stage_layers=nn.ModuleList() + for _ in range(stage_depth): + transformer_block = TransformerBlock(config, block_index) + stage_layers.append(transformer_block) + block_index += 1 + + self.transformer_stages.append(stage_layers) + + self.final_layer_norm = nn.LayerNorm(config.transformer["hidden_size"], eps=config.initialisation["layer_norm_eps"]) + #classifier head + self.classifier = nn.Sequential( + nn.Linear(config.transformer['hidden_size'], config.num_classes), + nn.Softmax(dim=1) # Apply softmax activation + ) + + def forward(self, x): + #Pass input through convolutional embedding layer + x = self.conv_embedding(x) + + #Propogate output sequentially through each stage + for stage in self.transformer_stages: + for transformer_block in stage: + x = transformer_block(x) + x = self.final_layer_norm(x) + + #Flatten representation at token level + x=x[:, 0] + + #Pass through classification head + logits = self.classifier(x) + return logits diff --git a/predict.py b/predict.py new file mode 100644 index 000000000..8e447c880 --- /dev/null +++ b/predict.py @@ -0,0 +1,47 @@ +import os +import argparse +import torch +from torchvision import transforms +from PIL import Image +from modules import ConvolutionalVisionTransformer, CViTConfig +from dataset import load_test_data,configuration + +def predict_image(image_path,model_path): + """ + Predict the class of an image using a pre-trained CViT model. An image is read and preprocessed from specified path, before feeding it into CViT model. The model + processes the image to determine the class. Function returns the class index predicted by the model. + """ + # Define a transform to preprocess the test image (adjust as needed) + transform = transforms.Compose([ + transforms.Resize((299,299)), + transforms.ToTensor(), + ]) + + # Load the model and set it to evaluation mode + config_params_dict=configuration() + config = CViTConfig(config_params_dict) + model = ConvolutionalVisionTransformer(config) + + # Load the trained model weights + model.load_state_dict(torch.load(model_path)) + model.eval() + + # Read and preprocess the test image + image = Image.open(image_path).convert("RGB") + image = transform(image).unsqueeze(0) # Add a batch dimension + + with torch.no_grad(): + output = model(image) + + # Process the model's output to get the predicted class + _, predicted_class = torch.max(output, 1) + + return predicted_class.item() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Image Prediction with CvT Model") + parser.add_argument("image_path", type=str, help="Path to the image file for prediction") + args = parser.parse_args() + model_save_path = "models.pth" + predicted_class = predict_image(args.image_path,model_save_path) + print(f"Predicted Class: {predicted_class}") diff --git a/train.py b/train.py new file mode 100644 index 000000000..19506245f --- /dev/null +++ b/train.py @@ -0,0 +1,204 @@ +import torch +import os +import argparse +import matplotlib.pyplot as plt +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from dataset import load_data, create_data_loaders,load_test_data +from dataset import ADNC_Dataset, get_image_paths_from_directory, extract_patient_id +from torchvision import transforms +from torch.utils.data import DataLoader +from modules import ConvolutionalVisionTransformer, CViTConfig,configuration + +def train_model(train_images_paths_AD, train_images_paths_NC, batch_size, model_save_path, num_epochs=10, learning_rate=0.001, num_classes=2, plot_path='plot'): + """ + Trains CViT for image classification, designed for distinguishing between two classes, (AD) and (NC). The function encompasses the full pipeline from loading the data, + training the model through specified epochs, and validating the model's performance, to saving the trained model and visualizing the training process statistics. + """ + # Load data + train_images_AD, train_images_NC, val_images_AD, val_images_NC = load_data(train_images_paths_AD, train_images_paths_NC) + + # Create data loaders + train_dataloader, val_dataloader = create_data_loaders(train_images_AD, train_images_NC, val_images_AD, val_images_NC, batch_size) + + config_params_dict=configuration() + config = CViTConfig(config_params_dict) + model = ConvolutionalVisionTransformer(config) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + + # Lists to store training and validation statistics for each epoch + train_losses = [] + train_accuracies = [] + val_losses = [] + val_accuracies = [] + for epoch in range(num_epochs): + # Training phase + model.train() + total_loss = 0.0 + correct_predictions = 0 + + for images, labels in train_dataloader: + optimizer.zero_grad() + outputs = model(images) + loss = criterion(outputs, labels) + print(loss) + print(outputs) + print(labels) + loss.backward() + optimizer.step() + + total_loss += loss.item() + _, predicted = torch.max(outputs, 1) + print(predicted) + correct_predictions += (predicted == labels).sum().item() + + # Calculate average loss and accuracy for the epoch + average_loss = total_loss / len(train_dataloader.dataset) + accuracy = correct_predictions / len(train_dataloader.dataset) + print(f'Epoch [{epoch + 1}/{num_epochs}] Train Loss: {average_loss:.4f} Train Accuracy: {accuracy:.4f}') + + # Validation phase + model.eval() + val_total_loss = 0.0 + val_correct_predictions = 0 + for val_images, val_labels in val_dataloader: + with torch.no_grad(): + val_outputs = model(val_images) + val_loss = criterion(val_outputs, val_labels) + val_total_loss += val_loss.item() + _, val_predicted = torch.max(val_outputs, 1) + val_correct_predictions += (val_predicted == val_labels).sum().item() + + val_average_loss = val_total_loss / len(val_dataloader.dataset) + val_accuracy = val_correct_predictions / len(val_dataloader.dataset) + print(f'Epoch [{epoch + 1}/{num_epochs}] Validation Loss: {val_average_loss:.4f} Validation Accuracy: {val_accuracy:.4f}') + # Save model and training stats in each epoch + + train_losses.append(average_loss) + train_accuracies.append(accuracy) + val_losses.append(val_average_loss) + val_accuracies.append(val_accuracy) + + print('Training completed.') + + # You can also save the training statistics for future analysis + training_stats = { + 'train_losses': train_losses, + 'train_accuracies': train_accuracies, + 'val_losses': val_losses, + 'val_accuracies': val_accuracies + } + # Define the number of epochs + num_epochs = len(training_stats['train_losses']) + # Check if the folder exists before creating it + if not os.path.exists(plot_path): + try: + os.makedirs(plot_path) + print("Plot folder created successfully.") + except OSError as e: + print("Failed to create the path folder:", e) + else: + print("Plot folder already exists.") + # Create a list of epoch numbers for x-axis + epochs = list(range(1, num_epochs + 1)) + + # Get training and validation losses and accuracies + train_losses = training_stats['train_losses'] + val_losses = training_stats['val_losses'] + train_accuracies = training_stats['train_accuracies'] + val_accuracies = training_stats['val_accuracies'] + + # Create subplots for loss and accuracy + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8)) + + # Plot loss + ax1.plot(epochs, train_losses, label='Train Loss', marker='o') + ax1.plot(epochs, val_losses, label='Validation Loss', marker='o') + ax1.set_xlabel('Epoch') + ax1.set_ylabel('Loss') + ax1.set_title('Training and Validation Losses') + ax1.legend() + + # Plot accuracy + ax2.plot(epochs, train_accuracies, label='Train Accuracy', marker='o') + ax2.plot(epochs, val_accuracies, label='Validation Accuracy', marker='o') + ax2.set_xlabel('Epoch') + ax2.set_ylabel('Accuracy') + ax2.set_title('Training and Validation Accuracies') + ax2.legend() + + plt.tight_layout() + plt.savefig(f'{plot_path}/plot.png') + + torch.save(model.state_dict(), f'{model_save_path}') + + +def test_model(model_path, test_imagesAD_path, test_images_nc_path, batch_size): + """ + Evaluate the performance of trained CViT model on a test dataset. The pretrained model and test dataset is loaded, and the model's + classification accuracy on the test data is evaluated. + """ + # Load test data + test_images_AD, test_images_NC = load_test_data(test_imagesAD_path, test_images_nc_path) + + # Create data loader for testing + data_transforms = { + 'test': transforms.Compose([ + transforms.Resize(299), + transforms.CenterCrop(299), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]), + } + + test_dataset = ADNC_Dataset(test_images_AD, test_images_NC, transform=data_transforms['test']) + test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) + config_params_dict=configuration() + config = CViTConfig(config_params_dict) + model = ConvolutionalVisionTransformer(config) + + # Load the trained model weights + model.load_state_dict(torch.load(model_path)) + model.eval() + + # Define a function to evaluate the model on the test set + def evaluate_model(model, dataloader): + model.eval() + correct_predictions = 0 + total_samples = 0 + + for images, labels in dataloader: + with torch.no_grad(): + outputs = model(images) + _, predicted = torch.max(outputs, 1) + correct_predictions += (predicted == labels).sum().item() + total_samples += labels.size(0) + + accuracy = correct_predictions / len(dataloader.dataset) + return accuracy + + # Evaluate the model on the test set + test_accuracy = evaluate_model(model, test_dataloader) + print(f'Test Accuracy: {test_accuracy:.4f}') + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Train a model with specified parameters") + + parser.add_argument("--epoch", type=int, default=2, help="Number of epochs to train (default: 2)") + parser.add_argument("--batch_size", type=int, default=2, help="Batch size for training (default: 2)") + + args = parser.parse_args() + + train_images_paths_AD = "AD_NC/train/AD" + train_images_paths_NC = "AD_NC/train/NC" + test_images_paths_AD = "AD_NC/test/AD" + test_images_paths_NC = "AD_NC/test/NC" + model_save_path = "models.pth" + save_plot = 'plot' + learning_rate = 0.001 + num_classes = 2 + + train_model(train_images_paths_AD, train_images_paths_NC, args.batch_size, model_save_path, args.epoch, learning_rate, num_classes, plot_path=save_plot) + test_model(model_save_path, test_images_paths_AD, test_images_paths_NC, args.batch_size)