diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..1ab3b3d98 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +recognition/MRI_TimothyTjipto/bin/lol.png +recognition/MRI_TimothyTjipto/bin/Note.txt +recognition/MRI_TimothyTjipto/bin/TEST.ipynb +recognition/MRI_TimothyTjipto/bin/test2.ipynb +recognition/MRI_TimothyTjipto/bin/Test3.ipynb +recognition/MRI_TimothyTjipto/bin/Tester3.ipynb +recognition/MRI_TimothyTjipto/bin/visualise_batch.png diff --git a/recognition/SiameseNetwork_s4653241/Images/Accuracy.png b/recognition/SiameseNetwork_s4653241/Images/Accuracy.png new file mode 100644 index 000000000..e88ec8c95 Binary files /dev/null and b/recognition/SiameseNetwork_s4653241/Images/Accuracy.png differ diff --git a/recognition/SiameseNetwork_s4653241/Images/Iteration loss.png b/recognition/SiameseNetwork_s4653241/Images/Iteration loss.png new file mode 100644 index 000000000..041192365 Binary files /dev/null and b/recognition/SiameseNetwork_s4653241/Images/Iteration loss.png differ diff --git a/recognition/SiameseNetwork_s4653241/Images/Loss.png b/recognition/SiameseNetwork_s4653241/Images/Loss.png new file mode 100644 index 000000000..804e37666 Binary files /dev/null and b/recognition/SiameseNetwork_s4653241/Images/Loss.png differ diff --git a/recognition/SiameseNetwork_s4653241/Images/Siamese_Network.png b/recognition/SiameseNetwork_s4653241/Images/Siamese_Network.png new file mode 100644 index 000000000..b8de1a034 Binary files /dev/null and b/recognition/SiameseNetwork_s4653241/Images/Siamese_Network.png differ diff --git a/recognition/SiameseNetwork_s4653241/Images/prediction/test1g.png b/recognition/SiameseNetwork_s4653241/Images/prediction/test1g.png new file mode 100644 index 000000000..53a25d5a0 Binary files /dev/null and b/recognition/SiameseNetwork_s4653241/Images/prediction/test1g.png differ diff --git a/recognition/SiameseNetwork_s4653241/Images/prediction/test2g.png b/recognition/SiameseNetwork_s4653241/Images/prediction/test2g.png new file mode 100644 index 000000000..90548d058 Binary files /dev/null and b/recognition/SiameseNetwork_s4653241/Images/prediction/test2g.png differ diff --git a/recognition/SiameseNetwork_s4653241/Images/prediction/test3g.png b/recognition/SiameseNetwork_s4653241/Images/prediction/test3g.png new file mode 100644 index 000000000..c8a239b38 Binary files /dev/null and b/recognition/SiameseNetwork_s4653241/Images/prediction/test3g.png differ diff --git a/recognition/SiameseNetwork_s4653241/Images/prediction/test4g.png b/recognition/SiameseNetwork_s4653241/Images/prediction/test4g.png new file mode 100644 index 000000000..4f83f3cd4 Binary files /dev/null and b/recognition/SiameseNetwork_s4653241/Images/prediction/test4g.png differ diff --git a/recognition/SiameseNetwork_s4653241/Images/prediction/test5g.png b/recognition/SiameseNetwork_s4653241/Images/prediction/test5g.png new file mode 100644 index 000000000..0ac0ef11c Binary files /dev/null and b/recognition/SiameseNetwork_s4653241/Images/prediction/test5g.png differ diff --git a/recognition/SiameseNetwork_s4653241/Images/prediction/test6g.png b/recognition/SiameseNetwork_s4653241/Images/prediction/test6g.png new file mode 100644 index 000000000..dd3fab074 Binary files /dev/null and b/recognition/SiameseNetwork_s4653241/Images/prediction/test6g.png differ diff --git a/recognition/SiameseNetwork_s4653241/Images/visualise_batch.png b/recognition/SiameseNetwork_s4653241/Images/visualise_batch.png new file mode 100644 index 000000000..033c6f7cf Binary files /dev/null and b/recognition/SiameseNetwork_s4653241/Images/visualise_batch.png differ diff --git a/recognition/SiameseNetwork_s4653241/README.md b/recognition/SiameseNetwork_s4653241/README.md new file mode 100644 index 000000000..0d5fc21bd --- /dev/null +++ b/recognition/SiameseNetwork_s4653241/README.md @@ -0,0 +1,98 @@ +# Siamese Network classifier to classify Alzheimer's disease + +[7]Project uses Siamese Network to predict the similarity between two images and classify normal or AD(Alzheimer's Disease). + +## Siamese network + +Siamese Network is a specialized neural network architectural which uses two identical subnetworks that shares parameters and weight. + +Use two images as inputs and produces a similarity score between the two. Then classify with a label. + +Popular with face verification, signature verification, and few-shot learning. + +![SiameseNetwork_Example](Images/Siamese_Network.png) + +Both uses the same identical Convolutional Neural Network(CNN). + +During Training, Siamese networks often uses pairs that are "similar" or "dissimilar". Network learns to minimize the distance for similar pairs and maximize it for dissimilar pairs. + +After processing the inputs, the final layer/differencing layer computes a distance metrics between the two outputs, often Euclidean distance. Similar items will have smaller distance between their output, while dissimilar items will have a larger distance. + +Main advantages of using Siamese Network is their ability to perform one-shot learning. The ability to recongize new classes or entities with little data. + +## ADNI brain dataset + +### 1. Data Preprocessing + +The ADNI brain dataset contains two classes, AD and NC in both Training and Test. All image has initial shape of (3,240,256) 256 x 240 (W x H) [batch_size,channel,height,width]. Loading the dataset, all images will be normalized.All Images are resized to (1,120,128) to increase training speed. + +It is then paired and labeled accourdingly if it is similar class or not. Pairing process randomly pics between 2 classes to pair up. In doing so doesn't overtrain the model. + +Batchloader is set to 16 to speed up the training process, can be change if needed. + +Figure shows a batch with labels stating if it's a Similar or Dissimilar pair. + +![visualise_pair](Images/visualise_batch.png) + +POS being Positive pair and NEG being Negative pair for visual purposes. + +### 2. Siamese Model + +The Siamese Model first part begins with the embedding where it transforms the input images into a continuous vector space. + +3 convolutional layers, 2 max-pooling layer and 2 dense layer with sigmoid activation function. + +Sigmoid activation for final layer as output is within specific range. + + + + +### 3. Training + +Contrastive Loss is used as a loss function as it is focus on learning the similarity or dissimilarity between pairs of the inputs. + +Optimizer is Adam with a learning rate of 0.00006 + +After 40 Epoch, + +![iteration_loss](Images/Iteration%20loss.png) + + +### 4. Testing + +Testing the trained model results, + +![loss](Images/Loss.png) +![accuracy](Images/Accuracy.png) + +### 5. Prediction + + +![predicton1](Images/prediction/test1g.png) ![prediction2](Images/prediction/test2g.png) ![prediction3](Images/prediction/test3g.png) + + + +## Code Discription +1. "dataset.py" contains Data loader for loading and preprocessing the dataset. + +2. "modules.py" contains Source code of the components of the model.Each component is implementated as a class or a function. + +3. "predict.py" contains to showexample usage of trained model. Print out any results and/ or provide visualisations where applicable. + +4. "train.py" contains the source code for training, validating, testing and saving the model. + - Change TRAIN_PATH to PATH of training dataset and set TRAINING_MODE to True if you want to use model for training. + - If use checkpoint of trained model to test edit CHECKPOINT PATH + + + +## **Dependencies** +1. Python 3.11.5 +2. External Libriaries: + - torch 2.01 + - matplotlib 3.8.0 + - torchvision 0.15.2 + - numpy 1.25.2 + + +## References +[1] Images of Achitecute of Siamese Neural Network https://www.latentview.com/blog/siamese-neural-network-a-face-recognition-case-study/ \ No newline at end of file diff --git a/recognition/SiameseNetwork_s4653241/__pycache__/dataset.cpython-310.pyc b/recognition/SiameseNetwork_s4653241/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 000000000..9dd4906e4 Binary files /dev/null and b/recognition/SiameseNetwork_s4653241/__pycache__/dataset.cpython-310.pyc differ diff --git a/recognition/SiameseNetwork_s4653241/__pycache__/modules.cpython-310.pyc b/recognition/SiameseNetwork_s4653241/__pycache__/modules.cpython-310.pyc new file mode 100644 index 000000000..35be25221 Binary files /dev/null and b/recognition/SiameseNetwork_s4653241/__pycache__/modules.cpython-310.pyc differ diff --git a/recognition/SiameseNetwork_s4653241/__pycache__/predict.cpython-310.pyc b/recognition/SiameseNetwork_s4653241/__pycache__/predict.cpython-310.pyc new file mode 100644 index 000000000..776afc988 Binary files /dev/null and b/recognition/SiameseNetwork_s4653241/__pycache__/predict.cpython-310.pyc differ diff --git a/recognition/SiameseNetwork_s4653241/__pycache__/train.cpython-310.pyc b/recognition/SiameseNetwork_s4653241/__pycache__/train.cpython-310.pyc new file mode 100644 index 000000000..337d8ab70 Binary files /dev/null and b/recognition/SiameseNetwork_s4653241/__pycache__/train.cpython-310.pyc differ diff --git a/recognition/SiameseNetwork_s4653241/dataset.py b/recognition/SiameseNetwork_s4653241/dataset.py new file mode 100644 index 000000000..c08cfd9db --- /dev/null +++ b/recognition/SiameseNetwork_s4653241/dataset.py @@ -0,0 +1,214 @@ +# Importing necessary libraries and modules +import torch +import numpy as np +import matplotlib.pyplot as plt +from torchvision import datasets,transforms +from torch.utils.data import DataLoader,ConcatDataset,Dataset,TensorDataset, Subset +from PIL import Image +import random + +def get_transform(): + """ + Returns a composed transform for preprocessing images. + Returns: + torchvision.transforms.Compose: A composed transform for preprocessing images. + """ + transform = transforms.Compose([ + transforms.Grayscale(num_output_channels=1), # Convert to grayscale with one channel + transforms.Resize((120,128)), # Resize to (120,128) + transforms.ToTensor(), + # You can add more transformations if needed + ]) + return transform + +def get_dataloader(dataset, batch_size = 16, shuffle=True): + """ + Returns a DataLoader for the given dataset. + + This function creates and returns a DataLoader for the provided dataset with the specified batch size and shuffling option. + + Parameters: + - dataset (Dataset): The dataset for which the DataLoader is to be created. + - batch_size (int, optional): The number of samples per batch. Default is 16. + - shuffle (bool, optional): Whether to shuffle the dataset before splitting into batches. Default is True. + + Returns: + - DataLoader: A DataLoader object for the given dataset with the specified parameters. + """ + v_dataloader = DataLoader(dataset, + shuffle=shuffle, + num_workers=1, + batch_size=batch_size) + return v_dataloader + + + +def visualise_1(dataset): + """ + Plots a random image in the dataset + + Args: + dataset (Dataset): Dataset which the random image will be picked + """ + img,lab = random.choice(dataset) + plt.title(lab) + plt.imshow(img) + plt.axis('off') + plt.savefig("visualise_1") + plt.show() + + +def visualise_batch(dataloader): + """ + Plots a batch of images from the dataloader + + Args: + dataloader (Dataloader): Dataset which a batch will be taken to be plotted. + """ + LABELS = ['POS','NEG'] + + example_batch = iter(dataloader) + images1,images2,labels = next(example_batch) + + plt.figure(figsize=(16,4)) # width x height + batch_size = len(images1) + for idx in range(batch_size): + + image1 = transforms.ToPILImage()(images1[idx]) + image2 = transforms.ToPILImage()(images2[idx]) + label = LABELS[int(labels[idx].item())] + + plt.subplot(2,batch_size,idx+1) + + plt.imshow(image1,cmap='gray') + plt.axis('off') + + plt.subplot(2,batch_size,idx+1+batch_size) + plt.imshow(image2,cmap='gray') + plt.title(label) + plt.axis('off') + + plt.savefig("visualise_batch") + plt.show() + + + +class SiameseNetworkDataset1(Dataset): + """ + A dataset class for creating pairs of images for Siamese networks. + + Args: + Dataset (torchvision.datasets.ImageFolder): A dataset object containing images and their labels. + transform (torchvision.transforms): A function/transform that takes in an image and returns a transformed version. + Default is None. + """ + def __init__(self,imageFolderDataset,transform=None): + + self.imageFolderDataset = imageFolderDataset + self.transform = transform + + def __getitem__(self,index): + """ + Returns a pair of images and a label indicating if they belong to the same class. + + Args: + index (int): Index (ignored) + + Returns: + tuple: A tuple containing two images and a label + """ + img0_tuple = random.choice(self.imageFolderDataset.imgs) + + #We need to approximately 50% of images to be in the same class + should_get_same_class = random.randint(0,1) + if should_get_same_class: + while True: + #Look untill the same class image is found + img1_tuple = random.choice(self.imageFolderDataset.imgs) + if img0_tuple[1] == img1_tuple[1]: + break + else: + + while True: + #Look untill a different class image is found + img1_tuple = random.choice(self.imageFolderDataset.imgs) + if img0_tuple[1] != img1_tuple[1]: + break + + img0 = Image.open(img0_tuple[0]) + img1 = Image.open(img1_tuple[0]) + + img0 = img0.convert("L") + img1 = img1.convert("L") + + if self.transform is not None: + img0 = self.transform(img0) + img1 = self.transform(img1) + + return img0, img1, torch.from_numpy(np.array([int(img1_tuple[1] != img0_tuple[1])], dtype=np.float32)) + + def __len__(self): + return len(self.imageFolderDataset.imgs) + + +class SiameseNetworkDataset_test(Dataset): + """ + A test dataset class for creating pairs of images for Siamese networks. + + Args: + Dataset (torchvision.datasets.ImageFolder): A dataset object containing images and their labels. + transform (torchvision.transforms): A function/transform that takes in an image and returns a transformed version. + Default is None. + """ + def __init__(self,imageFolderDataset,transform=None): + self.imageFolderDataset = imageFolderDataset + self.transform = transform + + def __getitem__(self,index): + """ + Returns a pair of images, a label indicating if they belong to the same class and labels of images. + + Args: + index (int): Index (ignored in this implementation as images are chosen randomly). + + Returns: + tuple: A tuple containing two images, a label (1 if the images are from different classes, 0 otherwise) and 2 labels of the respective images. + """ + img0_tuple = random.choice(self.imageFolderDataset.imgs) + + #We need to approximately 50% of images to be in the same class + should_get_same_class = random.randint(0,1) + if should_get_same_class: + while True: + #Look untill the same class image is found + img1_tuple = random.choice(self.imageFolderDataset.imgs) + if img0_tuple[1] == img1_tuple[1]: + break + else: + + while True: + #Look untill a different class image is found + img1_tuple = random.choice(self.imageFolderDataset.imgs) + if img0_tuple[1] != img1_tuple[1]: + break + + img0 = Image.open(img0_tuple[0]) + img1 = Image.open(img1_tuple[0]) + + img0 = img0.convert("L") + img1 = img1.convert("L") + + if self.transform is not None: + img0 = self.transform(img0) + img1 = self.transform(img1) + + return img0, img1, torch.from_numpy(np.array([int(img1_tuple[1] != img0_tuple[1])], dtype=np.float32)), img0_tuple[1],img1_tuple[1] + + def __len__(self): + """ + Returns the total number of images in the dataset. + + Returns: + int: Total number of images in the dataset. + """ + return len(self.imageFolderDataset.imgs) diff --git a/recognition/SiameseNetwork_s4653241/modules.py b/recognition/SiameseNetwork_s4653241/modules.py new file mode 100644 index 000000000..030fc74e2 --- /dev/null +++ b/recognition/SiameseNetwork_s4653241/modules.py @@ -0,0 +1,85 @@ +# Importing necessary libraries and modules +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SiameseNetwork(nn.Module): + """ + A Siamese Neural Network that takes in pair of images and returns vectors + for both images in the pair. The vectors are then used to determine the similarity + between the two images. + + Args: + nn (torch.nn.Module)): Base class for all neural network modules in PyTorch. + """ + + def __init__(self): + super(SiameseNetwork, self).__init__() + + # Setting up the Sequential of CNN Layers + self.cnn1 = nn.Sequential( + nn.Conv2d(1, 32, kernel_size=10,stride=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, stride=1), + + nn.Conv2d(32, 64, kernel_size=7, stride=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, stride=1), + + nn.Conv2d(64, 128, kernel_size=4,stride=1), + nn.ReLU(inplace=True) + ) + + + # Setting up the Fully Connected Layers + self.fc1 = nn.Sequential( + nn.Linear(128*108*100, 64), + nn.Sigmoid(), + nn.Linear(64,1), + nn.Sigmoid() + + + ) + + def forward_once(self, x): + # This function will be called for both images + # Its output is used to determine the similiarity + output = self.cnn1(x) + output = output.view(output.size()[0], -1) + output = self.fc1(output) + return output + + def forward(self, input1, input2): + # In this function we pass in both images and obtain both vectors + # which are returned + output1 = self.forward_once(input1) + output2 = self.forward_once(input2) + + return output1, output2 + + # Define the Contrastive Loss Function +class ContrastiveLoss(torch.nn.Module): + """ + Contrastive loss function. Computes the contrastive loss between pairs of samples based on their + distances and labels. + + Args: + margin (float): The margin value beyond which the loss will not incease. + It acts as a threshold to separate positive and negative pairs. + Default is 2.0. + """ + def __init__(self, margin=2.0): + super(ContrastiveLoss, self).__init__() + self.margin = margin + + def forward(self, output1, output2, label): + + euclidean_distance = F.pairwise_distance(output1, output2, keepdim = True) + + loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) + + (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)) + + + return loss_contrastive + diff --git a/recognition/SiameseNetwork_s4653241/predict.py b/recognition/SiameseNetwork_s4653241/predict.py new file mode 100644 index 000000000..155d6c29c --- /dev/null +++ b/recognition/SiameseNetwork_s4653241/predict.py @@ -0,0 +1,90 @@ +# Importing necessary libraries and modules +import matplotlib.pyplot as plt +import torch +from torchvision import transforms + + +def show_plot(iteration,loss): + """ + Plots the loss values against iterations and saves the resulting graph. + + Args: + iteration (List[int]): A list of iteration numbers. + loss (List[float]): A list of loss values corresponding to each iteration. + """ + plt.clf() + plt.plot(iteration,loss) + plt.savefig("Iteration loss") + plt.show() + +def predict(model, input1, input2): + """ + Make predictions using a trained PyTorch model. + + Args: + - model: Trained PyTorch model. + - input_data: PyTorch tensor containing the input data. + + Returns: + - Predictions as a PyTorch tensor. + """ + # Set the model to evaluation mode + model.eval() + + # Disable gradient computation + with torch.no_grad(): + predictions = model(input1, input2) + + return predictions + +def classify_pair(score, threshold): + """ + Classify pairs of samples based on a threshold. + + Args: + - score: Score of Dissimilarity + - threshold: Decision boundary for classification. + + Returns: + - List of classifications (0 for dissimilar, 1 for similar). + """ + + classification = 1 if score < threshold else 0 + return classification + + +def visual_pred_dis(idx,x0,x1,x0label,x1label,euclidean_distance,predict_class): + """ + Visualizes the predictions by plotting the input images and their predicted dissimilarity. + + Args: + idx (int): _description_ + x0 (torch.Tensor): First input image tensor + x1 (torch.Tensor): Second input image tensor + x0label (int): Label of first image + x1label (int): Label of second image + euclidean_distance (float): Calculated euclidean distance between the embeddings of the two images. + predict_class (int): Predicted class (0 for 'Different', 1 for 'Same'). + """ + Prediction = ['Different', 'Same'] + + plt.clf() + plt.subplot(2, 8, 1) + plt.title(int(x0label)) + x0_pic = transforms.ToPILImage()(x0[0]) + plt.axis('off') + plt.imshow(x0_pic, cmap='gray') + + plt.subplot(2,8,2) + plt.title(f'Dissimilarity: {euclidean_distance.item():.2f}\nClass predicted: {Prediction[predict_class]} ') + plt.axis('off') + + + + plt.subplot(2, 8, 3) + plt.title(int(x1label)) + x1_pic = transforms.ToPILImage()(x1[0]) + plt.axis('off') + plt.imshow(x1_pic, cmap='gray') + + plt.savefig(f'/home/Student/s4653241/MRI/Test_pic/test{idx}') \ No newline at end of file diff --git a/recognition/SiameseNetwork_s4653241/train.py b/recognition/SiameseNetwork_s4653241/train.py new file mode 100644 index 000000000..19b433859 --- /dev/null +++ b/recognition/SiameseNetwork_s4653241/train.py @@ -0,0 +1,196 @@ +# Importing necessary libraries and modules +import torch +from torchvision import datasets +from torch import optim +import torch.nn.functional as F + +# Importing custom modules +import dataset +from dataset import SiameseNetworkDataset1,SiameseNetworkDataset_test +from modules import SiameseNetwork, ContrastiveLoss +from predict import show_plot +import predict + +# Paths of data +TRAIN_PATH = "/home/Student/s4653241/AD_NC/train" +TEST_PATH = "/home/Student/s4653241/AD_NC/test" + +INPUT_SHAPE= (120, 128) # SIZE OF IMAGE 256 X 240 +BATCH_SIZE = 16 # Batch Size for DataLoader + +TRAINING_MODE = True # Training mode +EPOCH_RANGE = 61 # Size of the Training Epoch +CHECKPOINT_TRAINING = False # Use Checkpoint and continue Training +LOAD_CHECKPOINT_TRAINING = "/home/Student/s4653241/MRI/Training_Epoch/Epoch_40.pth" +SAVE_EPOCH = False +EPOCH_SAVE__CHECKPOINT = 60 # Saves every 60 Epoch + +TEST_MODE = True # For Testing +CHECKPOINT = "/home/Student/s4653241/MRI/Training_Epoch/Epoch_60.pth" # Test the checkpoint you want +TEST_RANGE = 500 # Testing size +THRESHOLD = 0.5 # Threshold Number +VISUALISE = False # Print out Error pics DEBUGGING TOOL for now + + +def load_checkpoint(path): + """ + Load a model and its parameters from a checkpoint. + + This function loads a SiameseNetwork model along with its optimizer state, epoch, counter, loss, and iteration + from a given checkpoint path. The model is moved to the CUDA device if available. + + Args: + path (str): Path to the checkpoint file. + + Returns: + tuple: A tuple containing the following elements: + - model (SiameseNetwork): The loaded SiameseNetwork model. + - optimizer (optim.Adam): The optimizer with its state loaded. + - epoch (int): The epoch at which the checkpoint was saved. + - counter (list): A list of counters indicating the progress of training. + - loss (list): A list of loss values recorded during training. + - iteration (int): The iteration number at which the checkpoint was saved. + """ + model = SiameseNetwork().cuda() + optimizer = optim.Adam(model.parameters(), lr = 0.00006) + + device = torch.device("cuda") + checkpoint = torch.load(LOAD_CHECKPOINT_TRAINING) + model.load_state_dict(checkpoint['model_state_dict'], strict=False) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + epoch = checkpoint['epoch'] + 1 + counter = checkpoint['counter'] + loss = checkpoint['loss'] + iteration = checkpoint['iteration'] + model.train() + return model,optimizer,epoch,counter,loss,iteration + +def main(): + """ + Main function to train and test a Siamese Network. + """ + training_transform = dataset.get_transform() + raw_dataset = datasets.ImageFolder(root=TRAIN_PATH) + siamese_dataset = SiameseNetworkDataset1(raw_dataset, training_transform ) + + training_dataloader = dataset.get_dataloader(siamese_dataset,BATCH_SIZE,True) + + dataset.visualise_batch(training_dataloader) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + net = SiameseNetwork().cuda() + criterion = ContrastiveLoss() + optimizer = optim.Adam(net.parameters(), lr = 0.00006) + current_epoch = 1 + + counter = [] + loss_history = [] + iteration_number= 0 + + if CHECKPOINT_TRAINING: + + net,optimizer,current_epoch,counter,loss_history,iteration_number=load_checkpoint(LOAD_CHECKPOINT_TRAINING) + + # Iterate throught the epochs + if TRAINING_MODE: + for epoch in range(current_epoch, EPOCH_RANGE): + + # Iterate over batches + for i, (img0, img1, label) in enumerate(training_dataloader, 0): + + # Send the images and labels to CUDA + img0, img1, label = img0.cuda(), img1.cuda(), label.cuda() + + # Zero the gradients + optimizer.zero_grad() + + # Pass in the two images into the network and obtain two outputs + output1, output2 = net(img0, img1) + + # Pass the outputs of the networks and label into the loss function + loss_contrastive = criterion(output1, output2, label) + + # Calculate the backpropagation + loss_contrastive.backward() + + # Optimize + optimizer.step() + + # Every 50 batches append loss + if i % 50 == 0 : + counter.append(iteration_number) + loss_history.append(loss_contrastive.item()) + iteration_number += 50 + + # Save Epoch to Checkpoint + if SAVE_EPOCH: + if epoch%EPOCH_SAVE__CHECKPOINT == 0: + checkpoint = { + 'epoch': epoch, + 'model_state_dict': net.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'counter':counter, + 'loss': loss_history, + 'iteration': iteration_number, + } + torch.save(checkpoint, f"/home/Student/s4653241/MRI/Training_Epoch/Epoch_{epoch}.pth") + + show_plot(counter, loss_history) + if TEST_MODE: + + # Load Checkpoint + if not TRAINING_MODE: + checkpoint = torch.load(CHECKPOINT) + model = SiameseNetwork().cuda() + model.load_state_dict(checkpoint['model_state_dict'], strict=False) + epoch = checkpoint['epoch'] + + + model.eval() + + raw_test_dataset = datasets.ImageFolder(root=TEST_PATH) + test_transform = dataset.get_transform() + test_siam = SiameseNetworkDataset_test(raw_test_dataset, test_transform) + test_dataloader = dataset.get_dataloader(test_siam,1,True) + + dataiter = iter(test_dataloader) + x0, _, _,x0label,_ = next(dataiter) + + postive_prediction = 0 + + # Test image + for i in range(TEST_RANGE): + + # Iterate over 10 images and test them with the first image (x0) + _, x1, label2,_,x1label = next(dataiter) + + + with torch.no_grad(): + output1, output2 = model(x0.cuda(), x1.cuda()) + euclidean_distance = F.pairwise_distance(output1, output2) + + predict_class = predict.classify_pair(euclidean_distance.item(),THRESHOLD) # Threshold + + if predict_class == 1 and int(x0label) == int(x1label): + + postive_prediction+=1 + + + if predict_class != 1 and int(x0label) != int(x1label): + + postive_prediction+=1 + + + if VISUALISE: + + predict.visual_pred_dis(i,x0,x1,x0label,x1label,euclidean_distance,predict_class) + + + Accuracy = postive_prediction/TEST_RANGE + print(f'Using Checkpoint: {CHECKPOINT}\nAccuracy: {Accuracy}\nNo. of Positive Matches: {postive_prediction}\nNo. of Test: {TEST_RANGE}') + + + +if __name__ == '__main__': + main() \ No newline at end of file