diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..b8ec4838d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+HipMRI_study_keras_slices_data
+__pycache__/
\ No newline at end of file
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
diff --git a/README.md b/README.md
index 4a064f841..3cb9e275f 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,71 @@
-# Pattern Analysis
-Pattern Analysis of various datasets by COMP3710 students at the University of Queensland.
-
-We create pattern recognition and image processing library for Tensorflow (TF), PyTorch or JAX.
-
-This library is created and maintained by The University of Queensland [COMP3710](https://my.uq.edu.au/programs-courses/course.html?course_code=comp3710) students.
-
-The library includes the following implemented in Tensorflow:
-* fractals
-* recognition problems
-
-In the recognition folder, you will find many recognition problems solved including:
-* OASIS brain segmentation
-* Classification
-etc.
+# Pattern Analysis
+Pattern Analysis of various datasets by COMP3710 students in 2024 at the University of Queensland.
+
+We create a pattern recognition and image processing library for TensorFlow (TF), PyTorch or JAX.
+
+This library is created and maintained by The University of Queensland [COMP3710](https://my.uq.edu.au/programs-courses/course.html?course_code=comp3710) students.
+
+The library includes the following implemented in TensorFlow:
+* fractals
+* recognition problems
+
+In the recognition folder, you will find many recognition problems solved, including:
+* segmentation
+* classification
+* graph neural networks
+* StyleGAN
+* Stable Diffusion
+* transformers
+
+# VQ-VAE for Medical Image Generation
+## Description
+This repository implements a Vector Quantized Variational Autoencoder (VQ-VAE) to generate medical images from 2D slices of prostate MRI scans. The VQ-VAE is a generative model that combines variational inference with vector quantization, learning a discrete latent representation that improves the quality of image reconstructions. The model maps prostate MRI scans into a lower-dimensional latent space and decodes the quantized representation back into image space, providing a method for generating clear MRI images with high similarity to real medical data.
+
+## Problem
+The challenge this VQ-VAE addresses is generating clear medical images from noisy data while maintaining high similarity to the original MRI scans. Specifically, the model is trained on the HipMRI Study dataset of 2D prostate MRI slices and reconstructs these images with a focus on structural similarity.
+
+## How It Works
+The VQ-VAE encodes each input image into a latent representation, which is then quantized to discrete codebook vectors. The quantized representation is decoded to reconstruct the original image (see the sketch after the flow below). The key components are:
+
+1. Encoder: Converts input images into a lower-dimensional latent representation using convolutional layers.
+2. Vector Quantizer: Discretizes the latent space by mapping each continuous latent vector to the nearest embedding in the codebook.
+3. Decoder: Reconstructs the input image from the quantized latent vectors using transposed convolution layers.
+4. Loss Function: The model optimizes a combination of reconstruction loss (MSE) and a commitment loss that keeps the encoder outputs close to the codebook embeddings.
+
+## High-Level Flow
+* Input MRI image -> Encoder -> Latent Space
+* Latent Space -> Vector Quantizer -> Quantized Latent Space
+* Quantized Latent Space -> Decoder -> Reconstructed MRI image
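+
+### Example: Forward Pass and Loss
+The following is a minimal, illustrative forward pass through the `VQVAE` class defined in `modules.py`. The tiny hyperparameter values are placeholders chosen only to keep the example fast; the values actually used for training are set in `train.py`.
+
+```python
+import torch
+from modules import VQVAE
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Tiny illustrative configuration (not the training configuration)
+model = VQVAE(hidden_dim=32, res_hidden_dim=16, num_res_layers=2,
+              num_embeddings=64, embedding_dim=8, beta=0.25).to(device)
+
+x = torch.rand(1, 1, 64, 64, device=device)   # one fake single-channel slice
+embedding_loss, x_hat, perplexity = model(x)  # encoder -> vector quantizer -> decoder
+
+recon_loss = torch.mean((x_hat - x) ** 2)     # reconstruction term (MSE), as in train.py
+loss = recon_loss + embedding_loss            # total VQ-VAE training loss
+print(x_hat.shape, perplexity.item(), loss.item())
+```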
+
+## Dependencies
+To install all dependencies, run:
+
+pip install torch torchvision numpy nibabel matplotlib tqdm scikit-image
+
+## Usage
+1. Preprocessing: The .nii MRI files are loaded as 2D slices; per-image intensity normalization is available via the `normImage` flag in `dataset.load_data_2D`.
+2. Training: The dataset is split into 70% for training, 15% for validation, and 15% for testing. This split is chosen to ensure a sufficiently large training set while keeping balanced validation and test sets for hyperparameter tuning and performance evaluation.
+3. To run training with the current hyperparameters and epochs, saving the model so the images can be reconstructed later:
+* > python train.py -save
+4. To reconstruct images after training and save the first five reconstructed images:
+* > python predict.py -save
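+
+### Loading a Saved Checkpoint
+The snippet below is a sketch of how the checkpoint written by `utils.save_model_and_results` can be reloaded outside of `predict.py`. It assumes the architecture values match those set in `train.py`; the checkpoint is a dictionary with the keys `model`, `results` and `hyperparameters`.
+
+```python
+import torch
+from modules import VQVAE
+
+# Architecture must match train.py: n_hiddens, n_residual_hiddens,
+# n_residual_layers, n_embeddings, embedding_dim, beta
+model = VQVAE(512, 256, 16, 1024, 512, 0.1)
+
+checkpoint = torch.load("results/vqvae_data.pth", map_location="cpu")
+model.load_state_dict(checkpoint["model"])   # trained weights
+history = checkpoint["results"]              # per-update loss/perplexity logs
+```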
+
+## Results
+### Validation SSIM
+The model achieves a structural similarity index (SSIM) of 0.75 averaged over 5 images, indicating good image quality and resemblance to the input MRI scans.
+
+## Working Principles
+### Algorithm
+* The VQ-VAE model solves the problem of medical image generation by learning discrete latent representations and producing high-quality MRI reconstructions.
+### Comments
+* All key components and methods in modules.py and the other files are well commented and structured to explain the core logic behind the model.
+### Formatting
+* The README is formatted using GitHub markdown.
+
+## Examples
+* Examples of the model's reconstructed images can be found in `./results`.
+
+# Author
+Harrison Cleland (47433386), 2024
diff --git a/README.pdf b/README.pdf
new file mode 100644
index 000000000..c175cc43a
Binary files /dev/null and b/README.pdf differ
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 000000000..cd2cd698e
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,70 @@
+import torch
+import numpy as np
+import nibabel as nib
+from tqdm import tqdm
+
+# Set device for PyTorch
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Convert a label array to one-hot (channel-per-class) encoding
+def to_channels(arr: np.ndarray, dtype=np.uint8) -> np.ndarray:
+    unique_classes = np.unique(arr)  # Unique class labels present in the array
+    one_hot = np.zeros(arr.shape + (len(unique_classes),), dtype=dtype)  # Pre-allocate one-hot array
+    for c in unique_classes:
+        c = int(c)
+        one_hot[..., c:c + 1][arr == c] = 1  # Set the channel for class c
+
+    return one_hot
+
+# Load 2D medical image data
+def load_data_2D(imageNames, normImage=False, categorical=False, dtype=np.float32,
+                 getAffines=False, early_stop=False):
+    '''
+    Load medical image data from the provided filenames.
+
+    normImage: bool (normalise each image to zero mean and unit variance)
+    early_stop: stop loading prematurely for quick tests
+    '''
+    affines = []  # Store affine transformations
+
+    num = len(imageNames)  # Number of images
+    print("Length of Images: ", num)
+    first_case = nib.load(imageNames[0]).get_fdata(caching='unchanged')  # Load first image
+    if len(first_case.shape) == 3:
+        first_case = first_case[:, :, 0]  # Take first slice if 3D
+    if categorical:
+        first_case = to_channels(first_case, dtype=dtype)  # Convert to one-hot if categorical
+        rows, cols, channels = first_case.shape
+        images = np.zeros((num, rows, cols, channels), dtype=dtype)  # Pre-allocate images
+    else:
+        rows, cols = first_case.shape
+        images = np.zeros((num, rows, cols), dtype=dtype)  # Pre-allocate images
+
+    # Load each image
+    for i, inName in enumerate(tqdm(imageNames)):
+        try:
+            niftiImage = nib.load(inName)  # Load NIfTI image
+            inImage = niftiImage.get_fdata(caching='unchanged')  # Get image data
+            affine = niftiImage.affine  # Get affine transformation
+            if len(inImage.shape) == 3:
+                inImage = inImage[:, :, 0]  # Take first slice if 3D
+            inImage = inImage.astype(dtype)  # Convert to the requested dtype
+            if normImage:
+                inImage = (inImage - inImage.mean()) / inImage.std()  # Standardise image
+            if categorical:
+                inImage = to_channels(inImage, dtype=dtype)  # Convert to one-hot if categorical
+                images[i, :, :, :] = inImage  # Store in images array
+            else:
+                images[i, :, :] = inImage  # Store in images array
+
+            affines.append(affine)  # Store affine
+            if early_stop and i > 20:
+                break  # Early stop if requested
+        except Exception as e:
+            print("Error occurred on image: ", i, inName, e)  # Report and skip the failing image
+
+    if getAffines:
+        return images, affines  # Return images and affines if requested
+    else:
+        return images  # Return only images
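+
+# Minimal usage sketch (illustrative only): load a handful of training slices as a
+# quick smoke test. The directory below matches the default --dataset_dir used in
+# train.py and is an assumption about where the data has been extracted.
+if __name__ == "__main__":
+    import os
+    data_dir = os.path.join("HipMRI_study_keras_slices_data", "keras_slices_train")
+    files = [os.path.join(data_dir, f) for f in os.listdir(data_dir)
+             if f.endswith((".nii", ".nii.gz"))]
+    # early_stop=True stops after roughly 20 images, keeping the smoke test fast
+    slices = load_data_2D(files, normImage=True, early_stop=True)
+    print("Loaded slices with shape:", slices.shape)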
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, beta: float): + super(VectorQuantizer, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.beta = beta + + self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) + self.embedding.weight.data.uniform_(-1.0 / self.num_embeddings, 1.0 / self.num_embeddings) + + def forward(self, z): + """ + Maps encoder output z to a discrete one-hot vector. + """ + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.embedding_dim) + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + min_encodings = torch.zeros( + min_encoding_indices.shape[0], self.num_embeddings).to(device) + min_encodings.scatter_(1, min_encoding_indices, 1) + + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + + # calculate loss + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + z_q = z + (z_q - z).detach() + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return loss, z_q, perplexity, min_encodings, min_encoding_indices + +class ResidualLayer(nn.Module): + """ + One residual layer. + """ + + def __init__(self, in_dim, h_dim, res_h_dim): + super(ResidualLayer, self).__init__() + self.res_block = nn.Sequential( + nn.ReLU(True), + nn.Conv2d(in_dim, res_h_dim, kernel_size=3, + stride=1, padding=1, bias=False), + nn.ReLU(True), + nn.Conv2d(res_h_dim, h_dim, kernel_size=1, + stride=1, bias=False) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.res_block(x) + + +class ResidualStack(nn.Module): + """ A stack of residual layers. 
""" + + def __init__(self, in_dim: int, hidden_dim: int, res_hidden_dim: int, num_res_layers: int): + super(ResidualStack, self).__init__() + self.stack = nn.ModuleList([ResidualLayer(in_dim, hidden_dim, res_hidden_dim) for _ in range(num_res_layers)]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for layer in self.stack: + x = layer(x) + return F.relu(x) + +class VQVAE(nn.Module): + def __init__(self, hidden_dim: int, res_hidden_dim: int, num_res_layers: int, + num_embeddings: int, embedding_dim: int, beta: float, save_img_embedding_map: bool = False, input_channels: int = 1): + super(VQVAE, self).__init__() + + self.encoder = Encoder(input_channels, hidden_dim, num_res_layers, res_hidden_dim) + self.pre_quantization_conv = nn.Conv2d(hidden_dim, embedding_dim, kernel_size=1, stride=1) + + self.vector_quantization = VectorQuantizer(num_embeddings, embedding_dim, beta) + + self.decoder = Decoder(embedding_dim, hidden_dim, num_res_layers, res_hidden_dim) + + if save_img_embedding_map: + self.img_to_embedding_map = {i: [] for i in range(num_embeddings)} + else: + self.img_to_embedding_map = None + + def forward(self, x: torch.Tensor, verbose: bool = False): + z_e = self.encoder(x) + z_e = self.pre_quantization_conv(z_e) + embedding_loss, z_q, perplexity, _, _ = self.vector_quantization(z_e) + x_hat = self.decoder(z_q) + + if verbose: + print('Original data shape:', x.shape) + print('Encoded data shape:', z_e.shape) + print('Reconstructed data shape:', x_hat.shape) + + return embedding_loss, x_hat, perplexity diff --git a/predict.py b/predict.py new file mode 100644 index 000000000..130fa2bef --- /dev/null +++ b/predict.py @@ -0,0 +1,96 @@ +# Import necessary libraries and modules +import torch +import numpy as np +import argparse +import utils +import matplotlib.pyplot as plt +from modules import VQVAE +import os +from skimage.metrics import structural_similarity as ssim +from utils import predict_and_reconstruct +import dataset + +# Define command-line arguments for dataset directory and model save path +parser = argparse.ArgumentParser() +epochs = 100 +learning_rate = 1e-3 +batch_size = 16 +weight_decay = 1e-5 + +# Model architecture parameters +n_hiddens = 512 +n_residual_hiddens = 512 +n_residual_layers = 32 +embedding_dim = 512 +n_embeddings = 1024 +beta = 0.1 + +# Dataset and model save path arguments, for easier readability and storage +# previous parameters were changes often, therefore variables were chosen to store them +parser.add_argument("--dataset_dir", type=str, default='HipMRI_study_keras_slices_data') +parser.add_argument("--save_path", type=str, default="vqvae_data.pth") +parser.add_argument("-save", action="store_true") + +# Parse command-line arguments and set device +args = parser.parse_args() +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Load model with specified checkpoint path +# had to include load_model function within predict.py since it used many variables +def load_model(model_path): + model = VQVAE(n_hiddens, n_residual_hiddens, n_residual_layers, + n_embeddings, embedding_dim, beta).to(device) + checkpoint = torch.load(model_path, map_location=device) + model.load_state_dict(checkpoint['model']) + return model + +def main(): + # Load test data from dataset directory + test_path = os.path.join(args.dataset_dir, 'keras_slices_test') + nii_files_test = [os.path.join(test_path, img) for img in os.listdir(test_path) if img.endswith(('.nii', '.nii.gz'))] + x_test = dataset.load_data_2D(nii_files_test, normImage=False, 
+
+if __name__ == "__main__":
+    main()
diff --git a/reconstructed_0.png b/reconstructed_0.png
new file mode 100644
index 000000000..f8cee09f8
Binary files /dev/null and b/reconstructed_0.png differ
diff --git a/reconstructed_1.png b/reconstructed_1.png
new file mode 100644
index 000000000..ce72e8363
Binary files /dev/null and b/reconstructed_1.png differ
diff --git a/reconstructed_2.png b/reconstructed_2.png
new file mode 100644
index 000000000..49dd641f9
Binary files /dev/null and b/reconstructed_2.png differ
diff --git a/reconstructed_3.png b/reconstructed_3.png
new file mode 100644
index 000000000..5b8f6aa02
Binary files /dev/null and b/reconstructed_3.png differ
diff --git a/reconstructed_4.png b/reconstructed_4.png
new file mode 100644
index 000000000..c178db87b
Binary files /dev/null and b/reconstructed_4.png differ
diff --git a/results/reconstructed_0.png b/results/reconstructed_0.png
new file mode 100644
index 000000000..f8cee09f8
Binary files /dev/null and b/results/reconstructed_0.png differ
diff --git a/results/reconstructed_1.png b/results/reconstructed_1.png
new file mode 100644
index 000000000..ce72e8363
Binary files /dev/null and b/results/reconstructed_1.png differ
diff --git a/results/reconstructed_2.png b/results/reconstructed_2.png
new file mode 100644
index 000000000..49dd641f9
Binary files /dev/null and b/results/reconstructed_2.png differ
diff --git a/results/reconstructed_3.png b/results/reconstructed_3.png
new file mode 100644
index 000000000..5b8f6aa02
Binary files /dev/null and b/results/reconstructed_3.png differ
diff --git a/results/reconstructed_4.png b/results/reconstructed_4.png
new file mode 100644
index 000000000..c178db87b
Binary files /dev/null and b/results/reconstructed_4.png differ
diff --git a/train.py b/train.py
new file mode 100644
index 000000000..1886bbebd
--- /dev/null
+++ b/train.py
@@ -0,0 +1,116 @@
+import torch
+import torch.optim as optim
+import argparse
+import utils
+import dataset
+from modules import VQVAE
+from tqdm import tqdm
+import os
+from skimage.metrics import structural_similarity as ssim
+
+parser = argparse.ArgumentParser()
+
+"""
+Hyperparameters
+"""
+epochs = 100
+learning_rate = 1e-3
+batch_size = 16
+weight_decay = 1e-5
+
+# define the model architecture
+n_hiddens = 512
+n_residual_hiddens = 256
+n_residual_layers = 16
+embedding_dim = 512
+n_embeddings = 1024
+beta = 0.1
+categorical = False
+normal_image = False
+
+# Add dataset and model save arguments for easier use and storage
+parser.add_argument("--dataset_dir", type=str, default='HipMRI_study_keras_slices_data')
+parser.add_argument("-save", action="store_true")
+
+args = parser.parse_args()  # argument parser
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# -save lets training runs be made without overwriting the currently saved model (useful for testing)
+if args.save:
+    print('Results will be saved in ./results/vqvae_data.pth')
+
+"""
+Load data and define batch data loaders for .nii files
+"""
+# locate the training and validation data
+train_path = os.path.join(args.dataset_dir, 'keras_slices_train')
+validate_path = os.path.join(args.dataset_dir, 'keras_slices_validate')
+
+# retrieve all .nii files from the specified data location
+nii_files_train = [os.path.join(train_path, img) for img in os.listdir(train_path) if img.endswith(('.nii', '.nii.gz'))]
+nii_files_validate = [os.path.join(validate_path, img) for img in os.listdir(validate_path) if img.endswith(('.nii', '.nii.gz'))]
+
+# extract the image data from the .nii files
+x_train = dataset.load_data_2D(nii_files_train, normImage=normal_image, categorical=categorical)
+x_val = dataset.load_data_2D(nii_files_validate, normImage=normal_image, categorical=categorical)
+
+# convert data to tensors for use in torch, adding a channel dimension
+x_train_tensor = torch.from_numpy(x_train).float().unsqueeze(1)
+x_val_tensor = torch.from_numpy(x_val).float().unsqueeze(1)
+
+# create data loaders for use in training
+train_loader = torch.utils.data.DataLoader(x_train_tensor, batch_size=batch_size, shuffle=True)
+val_loader = torch.utils.data.DataLoader(x_val_tensor, batch_size=batch_size)
+
+# initialise the model
+model = VQVAE(n_hiddens, n_residual_hiddens,
+              n_residual_layers, n_embeddings, embedding_dim, beta).to(device)
+
+"""
+Set up optimizer and training loop
+"""
+optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=True)
+
+model.train()
+
+# results dictionary to store model data
+results = {
+    'n_updates': 0,
+    'recon_errors': [],
+    'loss_vals': [],
+    'perplexities': [],
+}
+
+def train():
+
+    for epoch in range(epochs):
+        # tqdm provides a progress bar over the training batches for each epoch
+        for i, x in enumerate(tqdm(train_loader, desc=f'Epoch {epoch + 1}/{epochs}', unit='batch')):
+
+            x = x.to(device)       # move batch to device (preferably the GPU)
+            optimizer.zero_grad()  # zero out gradients
+
+            # forward pass
+            embedding_loss, x_hat, perplexity = model(x)
+
+            # compute reconstruction loss
+            recon_loss = torch.mean((x_hat - x) ** 2)
+            loss = recon_loss + embedding_loss
+
+            # backward pass and optimization
+            loss.backward()
+            optimizer.step()
+
+            # store results for logging and saving
+            results["recon_errors"].append(recon_loss.cpu().detach().numpy())
+            results["perplexities"].append(perplexity.cpu().detach().numpy())
+            results["loss_vals"].append(loss.cpu().detach().numpy())
+            results["n_updates"] = i
+
+    # save the model and data
+    if args.save:
+        hyperparameters = args.__dict__
+        utils.save_model_and_results(
+            model, results, hyperparameters)
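+
+# Sketch of a validation pass (an assumption about intended use, not called above):
+# val_loader and the ssim import are otherwise unused, so this illustrates one way
+# they could be wired in after training to produce a validation SSIM figure.
+def validate_ssim(model, loader, max_images=5):
+    model.eval()
+    scores = []
+    with torch.no_grad():
+        for x in loader:
+            x = x.to(device)
+            _, x_hat, _ = model(x)
+            x_np = x.cpu().numpy()
+            x_hat_np = x_hat.cpu().numpy()
+            for j in range(x_np.shape[0]):
+                orig = x_np[j, 0]
+                recon = x_hat_np[j].mean(axis=0)  # decoder outputs 3 channels; average them
+                scores.append(ssim(orig, recon, data_range=recon.max() - recon.min()))
+                if len(scores) >= max_images:
+                    return float(sum(scores) / len(scores))
+    return float(sum(scores) / len(scores))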
results["perplexities"].append(perplexity.cpu().detach().numpy()) + results["loss_vals"].append(loss.cpu().detach().numpy()) + results["n_updates"] = i + + # save the model and data + if args.save: + hyperparameters = args.__dict__ + utils.save_model_and_results( + model, results, hyperparameters) + +if __name__ == "__main__": + train() diff --git a/utils.py b/utils.py new file mode 100644 index 000000000..18ba55dcf --- /dev/null +++ b/utils.py @@ -0,0 +1,43 @@ +import torch +import os +from pathlib import Path +import torch.nn.functional as F + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def compute_loss(original, reconstructed, embedding_loss, prior_logits, beta=0.25): + """ + Computes the loss for VQ-VAE with PixelCNN prior. + - x: Original input image + - x_hat: Reconstructed image + - embedding_loss: Loss from the vector quantization step + - prior_logits: Output from PixelCNN + """ + reconstruction_loss = F.mse_loss(reconstructed, original) + prior_loss = F.cross_entropy(prior_logits, original.long()) + total_loss = reconstruction_loss + beta * embedding_loss + prior_loss + + return total_loss + +def save_model_and_results(model, results, hyperparameters): + SAVE_MODEL_PATH = os.getcwd() + '\\results' + directory = Path(SAVE_MODEL_PATH) + if not directory.exists(): + os.makedirs(directory) + results_to_save = { + 'model': model.state_dict(), + 'results': results, + 'hyperparameters': hyperparameters + } + torch.save(results_to_save, + SAVE_MODEL_PATH + '\\vqvae_data.pth') + +# Generate predictions and reconstructions without updating gradients +def predict_and_reconstruct(model, data_loader): + model.eval() + with torch.no_grad(): + for x in data_loader: + x = x.to(device) + _, x_hat, _ = model(x) + yield x.cpu().numpy(), x_hat.cpu().numpy()