diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..81ca4a91e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+
+*.pth
+*.pyc
diff --git a/recognition/vq-vae_s47036219/README.md b/recognition/vq-vae_s47036219/README.md
new file mode 100644
index 000000000..cfd8b9897
--- /dev/null
+++ b/recognition/vq-vae_s47036219/README.md
@@ -0,0 +1,70 @@
+# VQ-VAE for the ADNI Dataset
+
+**Author**: Connor Armstrong (s4703621)
+
+
+# Project:
+
+## The Vector Quantized Variational Autoencoder
+The goal of this task was to implement a Vector Quantized Variational Autoencoder (henceforth referred to as a VQ-VAE). The VQ-VAE is an extension of the variational autoencoder that learns a discrete latent representation: rather than letting the latent variables range over a continuum, the model represents each input with a set of distinct, discrete codes. It does this by passing the encoder's output through a vector quantisation layer, which maps each continuous encoding to the closest vector in an embedding space (the codebook). This makes the VQ-VAE very effective at modelling discrete, structured data and at image reconstruction and generation.
+
+
+## VQ-VAE Architecture
+![VQ-VAE Structure](./vqvae_structure.jpg)
+
+As shown above, the VQ-VAE comprises a few important components:
+
+- **Encoder**:
+    The encoder takes an input `x` and compresses it into a continuous latent space, producing `Z_e(x)`.
+
+- **Latent Predictor p(z)**:
+    This is not necessarily an actual module; in most VQ-VAE architectures it is not explicitly present. However, it is useful to think of the latent space as having some underlying probability distribution `p(z)` which the model tries to capture or mimic.
+
+- **Nearest Neighbors & Codebook**:
+    One of the most important features of the VQ-VAE is its discrete codebook, in which each entry is a vector. The continuous output of the encoder (`Z_e(x)`) is mapped to the nearest vector in this codebook, represented by the table at the bottom of the figure, where each row is a unique codebook vector. Mapping `Z_e(x)` to its nearest codebook vector yields `Z_q(x)`, a quantized version of the encoder's output.
+
+- **Decoder**:
+    The decoder takes the quantized latent representation `Z_q(x)` and reconstructs the original input, producing `x'`. Ideally, `x'` should be a close approximation of the original input `x`.
+
+The use of a discrete codebook in the latent space (instead of a continuous one) allows the VQ-VAE to capture complex data distributions with fewer latent variables.
+
+
+
+## VQ-VAE and the ADNI Dataset
+The ADNI (Alzheimer's Disease Neuroimaging Initiative) dataset is a collection of neuroimaging data curated primarily for the study of Alzheimer's disease. In this context, a VQ-VAE can be applied to condense complex brain scans into a more manageable, lower-dimensional, discrete latent space. By doing so, it can effectively capture meaningful patterns and structures inherent in the images.
+
+
+## Details on the implementation
+
+The goal of this project was to: "Create a generative model of the ADNI brain dataset using a VQ-VAE that has a “reasonably clear image” and a Structural Similarity (SSIM) of over 0.6"
+
+This implementation follows the standard VQ-VAE design: a convolutional encoder, a vector quantisation layer with a 512-entry codebook of 64-dimensional embeddings, and a convolutional decoder, trained on a weighted combination of an L2 reconstruction loss and an SSIM-based loss.
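+
+As a rough sketch of how these pieces fit together during one training step (condensed from `train.py` and `modules.py`; the random tensor below only stands in for a real batch of ADNI slices):
+
+```python
+import torch
+from modules import VQVAE, ssim
+
+L2_WEIGHT, SSIM_WEIGHT = 0.05, 1          # loss weights used in train.py
+
+model = VQVAE(num_embeddings=512, embedding_dim=64)
+img = torch.rand(8, 1, 64, 64)            # stand-in for a batch of 64x64 grayscale slices
+
+recon = model(img)                        # encode -> 1x1 conv -> quantise -> decode
+l2_loss = ((recon - img) ** 2).sum()      # summed squared reconstruction error
+ssim_loss = 1 - ssim(img, recon)          # 1 - SSIM, so lower is better
+loss = L2_WEIGHT * l2_loss + SSIM_WEIGHT * ssim_loss
+loss.backward()
+```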
+Other extensions could be of great use here: combining the VQ-VAE with a GAN or another generative model is a powerful way to improve on this implementation, but that is left for other students with more time.
+
+# Usage:
+**Please note: before running, set the paths to the train and test folders of the dataset in `train.py`.**
+
+It is highly recommended to run only the `predict.py` file by calling `python predict.py` from the working directory. It is possible to run from the `train.py` file as well, but this has data-leakage implications, as I could not find a proper way to partition the test set.
+
+If all goes well, matplotlib outputs four images: the original and reconstructed brain with the highest SSIM, followed by the pair with the lowest SSIM.
+
+# Data
+This project uses the ADNI dataset (in the structure as provided on Blackboard), where the training set is used to train the model, and the test folder is partitioned into a validation set and a test set.
+
+
+# Dependencies
+| Dependency  | Version      |
+|-------------|--------------|
+| torch       | 2.0.1+cu117  |
+| torchvision | 0.15.2+cu117 |
+| matplotlib  | 3.8.0        |
+
+# Output
+As stated earlier, these are the images with the highest and lowest SSIM scores:
+![Output Image](./output.png)
+
+# References
+The following sources inspired my implementation and were referenced in order to complete this project:
+* Aaron van den Oord, Oriol Vinyals, Koray Kavukcuoglu, Neural Discrete Representation Learning, 2017. https://arxiv.org/abs/1711.00937
+* ADNI Brain Dataset, thanks to https://adni.loni.usc.edu/
+* Misha Laskin, vqvae, https://github.com/MishaLaskin/vqvae/tree/master
+* Aurko Roy et al., Theory and Experiments on Vector Quantized Autoencoders, 2018. https://www.arxiv-vanity.com/papers/1805.11063/
\ No newline at end of file
diff --git a/recognition/vq-vae_s47036219/dataset.py b/recognition/vq-vae_s47036219/dataset.py
new file mode 100644
index 000000000..31bef8e9e
--- /dev/null
+++ b/recognition/vq-vae_s47036219/dataset.py
@@ -0,0 +1,23 @@
+from torch.utils.data import DataLoader, random_split
+from torchvision import datasets, transforms
+
+def get_dataloaders(train_string, test_validation_string, batch_size):
+    transform = transforms.Compose([
+        transforms.Grayscale(num_output_channels=1),
+        transforms.Resize((64, 64)),
+        transforms.ToTensor(),
+        #transforms.Normalize(mean=[0.5], std=[0.5])
+    ])
+    train_dataset = datasets.ImageFolder(root=train_string, transform=transform)
+    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+
+    # Split the test folder 30/70 into a held-out test set and a validation set
+    full_test_dataset = datasets.ImageFolder(root=test_validation_string, transform=transform)
+    test_size = int(0.3 * len(full_test_dataset))
+    val_size = len(full_test_dataset) - test_size
+
+    test_dataset, val_dataset = random_split(full_test_dataset, [test_size, val_size])
+
+    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
+    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+
+    return train_loader, val_loader, test_loader
diff --git a/recognition/vq-vae_s47036219/modules.py b/recognition/vq-vae_s47036219/modules.py
new file mode 100644
index 000000000..76218b859
--- /dev/null
+++ b/recognition/vq-vae_s47036219/modules.py
@@ -0,0 +1,132 @@
+import torch
+import torch.nn as nn
+
+
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, intermediate_channels=None):
+        super(ResidualBlock, self).__init__()
+
+        if not intermediate_channels:
+            intermediate_channels = in_channels // 2
+
+        self._residual_block = nn.Sequential(
+            nn.ReLU(),
+            nn.Conv2d(in_channels, intermediate_channels, kernel_size=3, stride=1, padding=1, bias=False),
+            nn.ReLU(),
+            nn.Conv2d(intermediate_channels, out_channels, kernel_size=1, stride=1, bias=False)
+        )
+
+    def forward(self, x):
+        # Identity skip connection around the two convolutions
+        return x + self._residual_block(x)
+
+
+class Encoder(nn.Module):
+    def __init__(self):
+        super(Encoder, self).__init__()
+
+        # Two stride-2 convolutions downsample 64x64 inputs to 16x16 feature maps
+        self.layers = nn.Sequential(
+            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(32),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.5),
+
+            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.5),
+
+            ResidualBlock(64, 64),
+            ResidualBlock(64, 64)
+        )
+
+    def forward(self, x):
+        out = self.layers(x)
+        return out
+
+class VectorQuantizer(nn.Module):
+    # Simplified quantizer: nearest-neighbour lookup only. It omits the
+    # straight-through estimator and the codebook/commitment losses of the
+    # original paper, so encoder gradients do not flow through this layer.
+    def __init__(self, num_embeddings, embedding_dim):
+        super(VectorQuantizer, self).__init__()
+
+        self.num_embeddings = num_embeddings  # Save as an instance variable
+        self.embedding = nn.Embedding(self.num_embeddings, embedding_dim)
+        self.embedding.weight.data.uniform_(-1./self.num_embeddings, 1./self.num_embeddings)
+
+    def forward(self, x):
+        batch_size, channels, height, width = x.shape
+        x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, channels)
+
+        # Now x_flat is [batch_size * height * width, channels]
+
+        # Calculate squared distances to every codebook vector
+        distances = ((x_flat.unsqueeze(1) - self.embedding.weight.unsqueeze(0)) ** 2).sum(-1)
+
+        # Find the closest embeddings (the one-hot encodings are kept for reference but unused)
+        _, indices = distances.min(1)
+        encodings = torch.zeros_like(distances).scatter_(1, indices.unsqueeze(1), 1)
+
+        # Quantize the input by looking up the nearest codebook vectors
+        quantized = self.embedding(indices)
+
+        # Reshape the quantized tensor to the same shape as the input
+        quantized = quantized.view(batch_size, height, width, channels).permute(0, 3, 1, 2)
+
+        return quantized
+
+class Decoder(nn.Module):
+    def __init__(self):
+        super(Decoder, self).__init__()
+
+        # Two transposed convolutions upsample 16x16 latents back to 64x64 images
+        self.layers = nn.Sequential(
+            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.5),
+
+            ResidualBlock(64, 64),
+            ResidualBlock(64, 64),
+
+            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(32),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.5),
+
+            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1)
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+class VQVAE(nn.Module):
+    def __init__(self, num_embeddings=512, embedding_dim=64):
+        super(VQVAE, self).__init__()
+
+        self.encoder = Encoder()
+        self.conv1 = nn.Conv2d(64, embedding_dim, kernel_size=1, stride=1)  # project encoder features to the embedding dimension
+        self.vector_quantizer = VectorQuantizer(num_embeddings, embedding_dim)
+        self.decoder = Decoder()
+
+    def forward(self, x):
+        enc = self.encoder(x)
+        enc = self.conv1(enc)
+        quantized = self.vector_quantizer(enc)
+
+        dec = self.decoder(quantized)
+        return dec
+
+
+def ssim(img1, img2, C1=0.01**2, C2=0.03**2):
+    # Simplified SSIM: global (not windowed) means and variances per image
+    mu1 = img1.mean(dim=[2, 3], keepdim=True)
+    mu2 = img2.mean(dim=[2, 3], keepdim=True)
+
+    sigma1_sq = (img1 - mu1).pow(2).mean(dim=[2, 3], keepdim=True)
+    sigma2_sq = (img2 - mu2).pow(2).mean(dim=[2, 3], keepdim=True)
+    sigma12 = ((img1 - mu1)*(img2 - mu2)).mean(dim=[2, 3], keepdim=True)
+
+    ssim_n = (2*mu1*mu2 + C1) * (2*sigma12 + C2)
+    ssim_d = (mu1.pow(2) + mu2.pow(2) + C1) * (sigma1_sq + sigma2_sq + C2)
+
+    ssim_val = ssim_n / ssim_d
+
+    return ssim_val.mean()
\ No newline at end of file
diff --git a/recognition/vq-vae_s47036219/output.png b/recognition/vq-vae_s47036219/output.png
new file mode 100644
index 000000000..c93e8362b
Binary files /dev/null and b/recognition/vq-vae_s47036219/output.png differ
diff --git a/recognition/vq-vae_s47036219/predict.py b/recognition/vq-vae_s47036219/predict.py
new file mode 100644
index 000000000..278a2f0da
--- /dev/null
+++ b/recognition/vq-vae_s47036219/predict.py
@@ -0,0 +1,102 @@
+import torch
+from modules import VQVAE, ssim
+from dataset import get_dataloaders
+from train import SSIM_WEIGHT, L2_WEIGHT, BATCH_SIZE, train_new_model, path_to_training_folder, path_to_test_folder
+import matplotlib
+import matplotlib.pyplot as plt
+import os
+
+def evaluate(test_loader):
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = VQVAE().to(device)
+    model.load_state_dict(torch.load('vqvae.pth'))
+    model.eval()
+    print("loaded")
+    highest_ssim_val = float('-inf')  # Initialize with negative infinity
+    lowest_ssim_val = float('inf')    # Initialize with positive infinity
+    highest_ssim_img = None
+    highest_ssim_recon = None
+    lowest_ssim_img = None
+    lowest_ssim_recon = None
+
+    val_losses = []
+    ssim_sum = 0      # Running sum of per-image SSIM values
+    total_images = 0  # Total number of images processed
+
+    with torch.no_grad():
+        for i, (img, _) in enumerate(test_loader):
+            img = img.to(device)
+
+            # Test forward pass (encoder -> 1x1 conv -> quantizer -> decoder)
+            z = model.encoder(img)
+            z = model.conv1(z)
+            z_q = model.vector_quantizer(z)
+            recon = model.decoder(z_q)
+
+            # Test losses
+            l2_loss = ((recon - img) ** 2).sum()
+            ssim_loss = 1 - ssim(img, recon)
+            loss = L2_WEIGHT * l2_loss + SSIM_WEIGHT * ssim_loss
+            val_losses.append(loss.item())
+
+            # Calculate SSIM for this batch; ssim() returns the batch mean,
+            # so weight it by the batch size before accumulating
+            ssim_val = ssim(img, recon).item()
+            ssim_sum += ssim_val * img.size(0)
+            total_images += img.size(0)
+
+            #print(f'SSIM: {ssim_val}')  # Output SSIM value
+
+            # Track the batches with the highest and lowest SSIM and keep their images
+            if ssim_val > highest_ssim_val:
+                highest_ssim_val = ssim_val
+                highest_ssim_img = img.cpu().numpy().squeeze(1)
+                highest_ssim_recon = recon.cpu().numpy().squeeze(1)
+
+            if ssim_val < lowest_ssim_val:
+                lowest_ssim_val = ssim_val
+                lowest_ssim_img = img.cpu().numpy().squeeze(1)
+                lowest_ssim_recon = recon.cpu().numpy().squeeze(1)
+
+    mean_ssim = ssim_sum / total_images
+    print(f'Mean SSIM: {mean_ssim}')  # Output mean SSIM value
+
+    # Plot the first image of the batches with the highest and lowest SSIM values
+    plt.figure(figsize=(10, 5))
+
+    plt.subplot(2, 2, 1)
+    plt.title(f'Original Highest SSIM: {highest_ssim_val}')
+    plt.imshow(highest_ssim_img[0], cmap='gray')
+
+    plt.subplot(2, 2, 2)
+    plt.title('Reconstructed')
+    plt.imshow(highest_ssim_recon[0], cmap='gray')
+
+    plt.subplot(2, 2, 3)
+    plt.title(f'Original Lowest SSIM: {lowest_ssim_val}')
+    plt.imshow(lowest_ssim_img[0], cmap='gray')
+
+    plt.subplot(2, 2, 4)
+    plt.title('Reconstructed')
+    plt.imshow(lowest_ssim_recon[0], cmap='gray')
+
+    plt.tight_layout()
+    plt.show()
+
+def main():
+    weight_file_path = "vqvae.pth"
+
+    # The train folder feeds the train loader; the test folder is split into validation and test sets
+    train, validate, test = get_dataloaders(path_to_training_folder, path_to_test_folder, BATCH_SIZE)
+
+    if os.path.exists(weight_file_path):
+        print("Weights exist -> Evaluating Model...")
+        evaluate(test)
+
+    else:
+        print(f"Weight file {weight_file_path} does not exist.")
+        print("Training model now...")
+        train_new_model(train, validate)
+
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/recognition/vq-vae_s47036219/train.py b/recognition/vq-vae_s47036219/train.py
new file mode 100644
index 000000000..f7eda4017
--- /dev/null
+++ b/recognition/vq-vae_s47036219/train.py
@@ -0,0 +1,105 @@
+# CONSTANTS AND HYPERPARAMETERS:
+
+import torch
+import torch.optim as optim
+from modules import VQVAE, ssim
+from dataset import get_dataloaders
+
+path_to_training_folder = "C:/Users/Connor/Documents/comp3710/dataset/ADNI/AD_NC/train"
+path_to_test_folder = "C:/Users/Connor/Documents/comp3710/dataset/ADNI/AD_NC/test"
+
+LEARNING_RATE = 1e-3
+BATCH_SIZE = 32
+NUM_EPOCHS = 40  # training usually stops earlier via early stopping on the validation loss
+CODEBOOK_SIZE = 512
+
+# Weights for the loss functions
+L2_WEIGHT = 0.05
+SSIM_WEIGHT = 1
+
+# Constant for early stopping
+PATIENCE = 12
+
+
+def train(vqvae, train_loader, validation_loader, device):
+    optimizer = optim.Adam(vqvae.parameters(), lr=LEARNING_RATE)
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, verbose=True)
+
+    # Early-stopping state
+    best_val_loss = float('inf')
+    counter = 0
+
+    # Training Loop
+    for epoch in range(NUM_EPOCHS):
+        for i, (img, _) in enumerate(train_loader):
+            vqvae.train()
+            img = img.to(device)  # Move to device
+
+            # Forward pass through the entire model
+            recon = vqvae(img)
+
+            # Calculate L2 loss
+            l2_loss = ((recon - img) ** 2).sum()
+
+            # Calculate SSIM loss
+            ssim_loss = 1 - ssim(img, recon)
+
+            # Final loss
+            loss = L2_WEIGHT * l2_loss + SSIM_WEIGHT * ssim_loss
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+        # Validation phase
+        val_losses = []
+        vqvae.eval()
+        with torch.no_grad():
+            for i, (img, _) in enumerate(validation_loader):
+                img = img.to(device)
+
+                # Validation forward pass
+                recon = vqvae(img)
+
+                # Validation losses
+                l2_loss = ((recon - img) ** 2).sum()
+                ssim_loss = 1 - ssim(img, recon)
+                loss = L2_WEIGHT * l2_loss + SSIM_WEIGHT * ssim_loss
+
+                val_losses.append(loss.item())
+
+        avg_val_loss = sum(val_losses) / len(val_losses)
+        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Training Loss: {loss.item():.4f}, Validation Loss: {avg_val_loss:.4f}")
+
+        # Update learning rate
+        scheduler.step(avg_val_loss)
+        # Print current learning rate
+        current_lr = optimizer.param_groups[0]['lr']
+        print(f"Current Learning Rate: {current_lr}")
+
+        # Early stopping
+        if avg_val_loss < best_val_loss:
+            best_val_loss = avg_val_loss
+            counter = 0  # Reset counter when validation loss decreases
+        else:
+            counter += 1
+            if counter >= PATIENCE:
+                print(f"Early stopping at epoch {epoch+1}")
+                break
+        # Save a checkpoint at the end of every epoch
+        torch.save(vqvae.state_dict(), 'vqvae.pth')
+
+    return vqvae
+
+def train_new_model(train_set, validation_set):  # Called from predict.py when no saved weights exist
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print("Running on: ", device)
+    model = VQVAE(CODEBOOK_SIZE).to(device)
+    model = train(model, train_set, validation_set, device)
+
+def main():
+    print("WARNING: RUNNING FROM TRAIN FILE")
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print("Running on: ", device)
+    train_loader, validation_loader, _ = get_dataloaders(path_to_training_folder, path_to_test_folder, BATCH_SIZE)
+
+    model = VQVAE(CODEBOOK_SIZE).to(device)
+    model = train(model, train_loader, validation_loader, device)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/recognition/vq-vae_s47036219/vqvae_structure.jpg b/recognition/vq-vae_s47036219/vqvae_structure.jpg
new file mode 100644
index 000000000..56d6f189b
Binary files /dev/null and b/recognition/vq-vae_s47036219/vqvae_structure.jpg differ