diff --git a/README.md b/README.md
index 4a064f841..0e4cc24c2 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,107 @@
-# Pattern Analysis
-Pattern Analysis of various datasets by COMP3710 students at the University of Queensland.
+**Introduction:**
-We create pattern recognition and image processing library for Tensorflow (TF), PyTorch or JAX.
+In this report, a VQVAE is developed in PyTorch to generate new images from the OASIS dataset (grayscale MRI images of the brain).
+A VQVAE (Vector-Quantised Variational Autoencoder) is a neural network architecture that learns discrete latent-space representations of its input and reconstructs the original input from these representations. Discretising the latent space gives VQVAEs several advantages over conventional VAEs, making the representations more compact and interpretable. An autoregressive neural network is then used to learn the space of discrete representations, so that new plausible outputs can be generated by sampling from this learned space.
-This library is created and maintained by The University of Queensland [COMP3710](https://my.uq.edu.au/programs-courses/course.html?course_code=comp3710) students.
+**Hyperparameters:**
-The library includes the following implemented in Tensorflow:
-* fractals
-* recognition problems
+The hyperparameters were tuned throughout the project. A learning rate of 5e-4 and a batch size of 128 yielded the best results for the VQVAE. Learning-rate scheduling was considered but was not necessary for these models.
+The number of embeddings was chosen to be 256, with an initial embedding dimension of 32. The embedding dimension was later increased, as the low value proved insufficient for capturing the fine details of the dataset and bottlenecked the model. This caused the loss to plateau at roughly 0.3, with the model not achieving the desired sharpness because of this limiting factor.
+
+**Results:**
+
+The results are shown below. The model generates plausible images, albeit very blurry ones, with much room for improvement. After 60 epochs, the VQVAE produced accurate, though blurry, reconstructions. A graph of the losses is shown below. The individual losses (codebook loss, commitment loss and reconstruction loss) were checked repeatedly to confirm the expected behaviour.
+![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/fac1f5a9-42ce-4461-8418-52253bac6bf6)
+
+This VQVAE produced the outputs shown below:
+![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/2821e6e7-ee87-4b0e-9ae9-955b864f8df3)
+
+
+These results demonstrated that the model was not overfitting and could be trained for more epochs. The loss, however, plateaued at roughly 0.3, an indication that the model was not complex enough to capture the detail of the dataset. The PixelCNN was then trained for 25 epochs, with the results shown below (top: generated indices; left: bottom half of the brain generated; right: entire image generated):
+![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/a4f2a8d4-430a-4f4e-bd5d-5ce9487003c1)
+![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/944c35c0-4cee-4d98-90fa-32a8bcfb5a4b)
+![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/37bd9bd0-6869-4af4-8982-6a1405dc295a)
+
+
+These results show both models working as intended, even after only 60 and 25 epochs respectively.
+The loss graph for the PixelCNN is shown below (note that the x-axis for the validation loss represents the total number of samples, not epochs):
+![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/727e40c9-7079-41d8-8b65-11f1dac486ba)
+
+
+Overfitting was also avoided here, as can be seen from the graph; however, the PixelCNN was limited by the VQVAE (which at this point had only been trained for 60 epochs).
+This led to the number of hidden channels in the encoder and decoder being increased slightly, along with the embedding dimension (if necessary, the number of embeddings can also be increased).
+The new model performed far better than the old one (although it did not save correctly due to a failure of the technology used) and was able to complete the task. The model provided is able to generate far clearer and more accurate images after approximately 80 and 50 epochs respectively (without overfitting).
+If even more accurate results are required, the overall number of parameters (particularly the embedding dimension and the number of embeddings) can be increased further.
+
+**How to Use:**
+
+To run this project, the environment must have the following installed:
+- PyTorch
+- tqdm
+- NumPy
+- PIL (Pillow)
+- Matplotlib
+Create a new conda environment with Python 3.8, then run ‘pip install <library>’ for each of the libraries above in the command terminal to install any that are missing. The OASIS dataset will also need to be downloaded and placed in a directory with three folders – one containing the train set, one the validation set, and one the test set. This report contains four main files:
+- modules.py: the VQVAE and PixelCNN models
+- dataset.py: classes for loading and preprocessing the data
+- train.py: functions for training and saving the models, as well as plotting the losses
+- predict.py: example usage of the trained models
+The train.py and predict.py files in particular have configuration-specific variables (such as paths, embedding_dim, etc.) which must be changed at the top of each file.
+Any changes to the models and data can be made in modules.py and dataset.py respectively. To train the models, use the train.py file. Replace the current path variable with the local path to the OASIS dataset, the folder names in the ‘DataPreparer’ instance, and the path where the models and losses will be saved. (Note that the training functions save the models to a new file every 20 epochs.) The test functions in train.py also have optional parameters for visualising the data.
+New images can be generated from the predict.py file. The function show_reconstructed_imgs(num_shown=2) shows the VQVAE input images, the codebook indices, and the reconstructions.
+show_generated_indices(shown_imgs=2) shows the codebook indices generated by the PixelCNN. This calls cnn_model.sample(), which takes in codebook indices and generates any that have been replaced with ‘-1’. The line ‘masked_indices[:,:,16:,:] = -1’ can be changed to alter which indices are generated.
+show_generated_output() shows a newly generated image. Again, the masked_indices variable can be changed so that parts of images are generated conditioned on existing pixels.
+
+**Data Processing:**
+
+In the dataset class, the data is loaded into train, validation and test datasets. The transforms applied to all data are conversion to a tensor and normalisation, based on the mean and standard deviation of the entire dataset (see the sketch below). The numbers of images in the train, validation, and test datasets are in the approximate ratio 20:2:1.
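+Below is a minimal sketch of how these dataset-wide statistics can be computed and then applied. It mirrors the transform hard-coded in dataset.py (mean 0.13242, std 0.18826); the folder path and the compute_mean_std helper are illustrative and not part of the repository.
+```python
+import os
+import numpy as np
+from PIL import Image
+import torchvision.transforms as transforms
+
+def compute_mean_std(folder):
+    """Return the mean and std of all grayscale PNGs in a folder, scaled to [0, 1]."""
+    pixels = []
+    for f in os.listdir(folder):
+        if f.lower().endswith('.png'):
+            img = np.array(Image.open(os.path.join(folder, f)).convert('L'), dtype=np.float32) / 255.0
+            pixels.append(img.ravel())
+    all_pixels = np.concatenate(pixels)
+    return all_pixels.mean(), all_pixels.std()
+
+# mean, std = compute_mean_std("path/to/keras_png_slices_train")  # approx. 0.13242, 0.18826
+transform = transforms.Compose([
+    transforms.ToTensor(),                           # uint8 image -> float tensor in [0, 1]
+    transforms.Normalize(mean=0.13242, std=0.18826)  # dataset-wide statistics
+])
+```
+Because the images are single-channel, a single mean and standard deviation is sufficient.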
+
+**Models Overview:**
+
+Note – the VQVAE model described below uses an embedding dimension of 32. This was the original model; the embedding dimension has since been increased.
+VQVAE
+The VQVAE model can be broken up into the encoder, vector quantiser, and decoder. Throughout the process, the data is generally in the shape (Batch, Channel, Height, Width). The batch size used is 128, as this was found to yield the best results overall. A visual depiction and descriptions of the VQ-VAE components are given below.
+ ![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/0203148a-b794-4d82-862f-25fcaaa06baa)
+
+(Image from Neural Discrete Representation Learning2, 2018)
+Encoder:
+(B=128, C=1, H=256, W=256) -> (B=128, C=32, H=32, W=32)
+The encoder is a convolutional neural network which maps the grayscale (single-channel) input images into a continuous latent space with 32 channels. Three convolutions are performed, and after each convolution the data is normalised and a non-linear activation function is applied. Each convolution has a stride of 2, so the height and width of the latent space are reduced by a factor of 8 overall.
+In the current model, the encoder changes the shape of the data from (128, 1, 256, 256) to (128, 32, 32, 32). To modify any layers of the model, changes can be made to the self.encoder attribute. (Note that the output channels of the encoder and the input channels of the decoder should correspond to the embedding dimension of the vector quantiser.)
+Vector Quantiser:
+(B=128, C=32, H=32, W=32) -> (B=128, C=32, H=32, W=32)
+The quantiser contains an embedding space/codebook (which it learns). For each spatial position, the vector of 32 channel values at that position is compared to each of the 32-dimensional embeddings in the codebook. The index of the embedding which is closest (based on Euclidean distance) to the given vector is returned, and the embedding vector corresponding to this index is then looked up; thus, the quantiser maps the 32-channel continuous space to a discrete space.
+This process is broken into a number of functions, increasing the readability of the code and allowing these functions to be called individually when generating images. The functions include:
+- get_encoding_indices, which takes the original input (B, C, H, W) and computes the Euclidean distance to each embedding. It then finds the indices of the closest embeddings, returning the codebook indices in the shape (B, H*W).
+- output_from_indices, which takes the codebook indices of shape (B, H*W) and returns the corresponding embedding vectors. It also reshapes the output into the correct shape (B, C, H, W) before returning it.
+The forward function calls both these functions and contains a straight-through gradient estimator (shown by the red arrow in the diagram above). This is needed because the discretisation is non-differentiable, so back-propagation could otherwise not occur. If x is the input to the quantiser and Z(x) the output, then x + (Z(x) - x) is still Z(x); however, detaching (Z(x) - x) ensures that backpropagation follows the dotted line shown below, avoiding the non-differentiable Z(x) block.
+ ![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/cb007ed1-244d-4931-b0e5-ab9a03007541)
+
+The forward function also returns the codebook indices (reshaped back to (B, H, W)) as well as the associated loss, comprising the codebook loss and the commitment loss, summarised in the sketch below.
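+The following condensed sketch covers the nearest-embedding lookup, the two loss terms, and the straight-through step. It mirrors Quantiser.forward in modules.py, but quantise and its argument names are illustrative rather than the repository's actual API; the loss terms are described in more detail after the code.
+```python
+import torch
+import torch.nn.functional as F
+
+def quantise(z_e: torch.Tensor, codebook: torch.Tensor, beta: float = 0.2):
+    """z_e: encoder output (B, C, H, W); codebook: embedding weights (num_embeddings, C)."""
+    B, C, H, W = z_e.shape
+    flat = z_e.permute(0, 2, 3, 1).reshape(-1, C)       # one C-dimensional vector per spatial position
+    dist = torch.cdist(flat, codebook)                  # Euclidean distance to every embedding
+    indices = dist.argmin(dim=-1)                       # index of the nearest embedding
+    z_q = codebook[indices]                             # look up the corresponding embedding vectors
+
+    codebook_loss = F.mse_loss(z_q, flat.detach())      # pulls embeddings towards encoder outputs
+    commitment_loss = F.mse_loss(z_q.detach(), flat)    # pulls encoder outputs towards embeddings
+    loss = codebook_loss + beta * commitment_loss
+
+    z_q = flat + (z_q - flat).detach()                  # straight-through gradient estimator
+    z_q = z_q.reshape(B, H, W, C).permute(0, 3, 1, 2)   # back to (B, C, H, W)
+    return z_q, loss, indices.reshape(B, H, W)
+```
+(In modules.py the two loss terms are written as explicit mean-squared differences, which is equivalent to the F.mse_loss calls above.)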
+The codebook loss corresponds to the distance between the latent vectors produced by the encoder and the closest vectors in the codebook, which ensures that the vectors in the embedding space are effective representations. The commitment loss keeps the encoder outputs close to their chosen codebook vectors, preventing them from oscillating between embeddings and encouraging convergence. The vector quantiser has a ‘beta’ attribute (currently 0.2), and the quantiser loss is calculated as quant_loss = codebook_loss + beta*commitment_loss. The ‘beta’ value can be changed to affect how strongly the quantiser enforces commitment.
+Decoder:
+(B=128, C=32, H=32, W=32) -> (B=128, C=1, H=256, W=256)
+The decoder takes the output of the vector quantiser, of shape (B, C, H, W) (currently (B, 32, 32, 32)), and performs three transposed convolutions to reconstruct the original image. Each transposed convolution is again followed by normalisation and a non-linear activation function.
+
+The forward call of the VQVAE passes the image through the encoder, the vector quantiser, then the decoder. It calculates the reconstruction loss, which represents the dissimilarity between the input image and the output image. It returns the output image, the vector quantiser loss, the reconstruction loss, and the encoding indices (helpful for visualisation). The VQVAE model also contains the function img_from_indices. This function takes in the codebook indices (B, H, W) and returns the reconstructed image directly, which is used during image generation.
+PixelCNN
+The autoregressive PixelCNN used is based on an implementation for MNIST data by Phillip Lippe1.
+The PixelCNN attempts to learn the space of discrete codebook indices of the VQVAE, such that new plausible combinations can be predicted. This is achieved through a masked convolution, which does not allow the network to see the current or future pixels: a mask that zeroes out the weights for those positions is multiplied with the kernel weights before each convolution.
+
+
+However, using this kernel directly leads to a blind spot in the receptive field, so a gated convolution is used instead, with both a vertical stack and a horizontal stack, with the following architecture:
+![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/99deeeff-ca6d-48b7-b379-9a27b8f27944)
+
+(Figure by Aaron van den Oord et al.)
+The network then learns to predict future indices. A number of classes were used in developing the PixelCNN:
+- MaskedConvolution: the base class, which takes a mask as a parameter and performs a masked convolution.
+- VerticalConv: extends MaskedConvolution; a vertical-stack convolution which considers all pixels above the current pixel.
+- HorizontalConv: extends MaskedConvolution; a horizontal-stack convolution which considers all pixels to the left of the current pixel. The horizontal stack also carries residual connections, because its output is used for the final prediction.
+- GatedMaskedConv: performs the horizontal and vertical stack convolutions and combines them according to the diagram above.
+The PixelCNN first scales all of the indices to the range -1 to 1. It performs the initial convolutions (which also mask the centre pixel), followed by a number of gated masked convolutions. It then applies a non-linear activation function and reshapes the output to (Batch, Classes, Channels, Height, Width). This allows the cross-entropy loss between the predicted and actual indices to be computed, as sketched below.
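+A condensed sketch of this loss computation, mirroring the training loop in train.py, is shown below; the tensors here are random stand-ins for the PixelCNN logits and the VQVAE codebook indices.
+```python
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+num_embeddings = 256
+B, C, H, W = 4, 1, 32, 32
+
+logits = torch.randn(B, num_embeddings, C, H, W)          # PixelCNN output: one logit per codebook entry
+indices = torch.randint(0, num_embeddings, (B, C, H, W))  # target codebook indices from the VQVAE
+
+nll = F.cross_entropy(logits, indices, reduction='none')  # negative log-likelihood per pixel (B, C, H, W)
+bpd = nll.mean(dim=[1, 2, 3]) * np.log2(np.e)             # convert nats to bits per dimension
+loss = bpd.mean()
+```
+The factor log2(e) simply converts the cross-entropy from nats to bits.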
Note that for this model, the loss is given by the bpd (bits per dimension), which is calculated based on the cross-entropy, but is more suitable for the model. +The model also contains a function called ‘sample’. This takes an input of indices, with any indices to generate replaced with ‘-1’. It iterates through the pixels and channels and uses Softmax to predict which indices are likely. For full image generation, the indices can be replaced entirely with ‘-1’s. + +**References** + +1. Lippe, P. (no date) Tutorial 12: Autoregressive Image modelling, Tutorial 12: Autoregressive Image Modelling - UvA DL Notebooks v1.2 documentation. Available at: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial12/Autoregressive_Image_Modeling.html (Accessed: 09 October 2023) +2. Aaron van den Oord, Oriol Vinyals, Koray Kavukcuoglu (30/05/2018), Neural Discrete Representation Learning. -In the recognition folder, you will find many recognition problems solved including: -* OASIS brain segmentation -* Classification -etc. diff --git a/dataset.py b/dataset.py new file mode 100644 index 000000000..70c2e219a --- /dev/null +++ b/dataset.py @@ -0,0 +1,70 @@ + +"""Classes for creating dataloaders of grayscale images from folders""" + + +import os +import torch +import torchvision.transforms as transforms +import numpy as np +from tqdm.auto import tqdm +from PIL import Image +import torch.utils.data + + +"""Prepares train, validation and test datasets""" +# Note, making datasets class variables can prevent needing to reload data every instance +# (or add conditional to see if data is previously loaded in init()) +class DataPreparer(): + + def __init__(self, path, train_folder, validation_folder, test_folder, batch_size): + # Transform (to be applied to all datasets) + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=0.13242, std=0.18826) + ]) + self.batch_size = batch_size + # Initialise the datasets + train_data = self.load_data_from_folder(path, train_folder) + validate_data = self.load_data_from_folder(path, validation_folder) + test_data = self.load_data_from_folder(path, test_folder) + + # Transform the data, and stack it into a tensor (B, C, H, W) + self.train_dataset = torch.stack([transform(item) for item in train_data]).permute(0, 2, 3, 1) + self.validate_dataset = torch.stack([transform(item) for item in validate_data]).permute(0, 2, 3, 1) + self.test_dataset = torch.stack([transform(item) for item in test_data]).permute(0, 2, 3, 1) + + # Create dataloaders + self.train_dataloader = self.prepare_dataset(self.train_dataset) + self.validate_dataloader = self.prepare_dataset(self.validate_dataset) + self.test_dataloader = self.prepare_dataset(self.test_dataset) + + """Takes in a dataset, returns a dataloader""" + def prepare_dataset(self, dataset): + # Create and return dataloader + dataloader = torch.utils.data.DataLoader(dataset, self.batch_size, shuffle=True) + return dataloader + + """Given a path and a folder name, returns a numpy array of the images (B, C=1, H, W)""" + def load_data_from_folder(self, path, name): + # Initialise data as an empty list + data = [] + i = 0 + # Make a list of image names + list_files = [f for f in os.listdir(path+name) if f.lower().endswith('.png')] + + for filename in tqdm(list_files): + # Load image as numpy array + image_path = os.path.join(path+name, filename) + # Assign and convert to grayscale + image = Image.open(image_path).convert('L') + image = np.array(image) + # Add channel, so img shape is 'C=1, 
H, W' + image = np.expand_dims(image, axis=0) + + data.append(image) + + if i == 50: + return data + i += 1 + + return np.array(data) \ No newline at end of file diff --git a/modules.py b/modules.py new file mode 100644 index 000000000..7707f4838 --- /dev/null +++ b/modules.py @@ -0,0 +1,284 @@ +"""Contains VQVAE and PixelCNN models""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data + +# ---------------------------------------------------------------------------- +# VQVAE MODEL +# ---------------------------------------------------------------------------- + +"""The VQ-VAE Model""" +class VQVAE(nn.Module): + + def __init__(self, num_embeddings, embedding_dim): + super(VQVAE, self).__init__() + + self.epochs_trained = 0 + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + + self.encoder = Encoder(embedding_dim) + self.quantiser = Quantiser(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + self.decoder = Decoder(embedding_dim) + + def forward(self, x): + # Input shape is B, C, H, W + quant_input = self.encoder(x) + quant_out, quant_loss, encoding_indices = self.quantiser(quant_input) + output = self.decoder(quant_out) + + # Reconstruction Loss, and find the total loss + reconstruction_loss = F.mse_loss(x, output) + + return output, quant_loss, reconstruction_loss, encoding_indices + + """Function while allows output to be calculated directly from indices + param quant_out_shape is the shape that the quantiser is expected to return""" + @torch.no_grad() + def img_from_indices(self, indices, quant_out_shape_BHWC): + quant_out = self.quantiser.output_from_indices(indices, quant_out_shape_BHWC) # Output is currently 32*32 img with 32 channels + return self.decoder(quant_out) + +"""The Encoder Model used in VQ-VAE""" +class Encoder(nn.Module): + def __init__(self, embedding_dim): + super(Encoder, self).__init__() + + self.encoder = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(), + nn.Conv2d(32, embedding_dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(embedding_dim), + nn.ReLU(), + ) + + def forward(self, x): + out = self.encoder(x) + return out + +"""The VectorQuantiser Model used in VQ-VAE""" +class Quantiser(nn.Module): + def __init__(self, num_embeddings, embedding_dim) -> None: + super(Quantiser, self).__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.beta = 0.2 + + self.embedding = self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + + """Returns the encoding indices from the input""" + def get_encoding_indices(self, quant_input): + # Flatten + quant_input = quant_input.permute(0, 2, 3, 1) + quant_input = quant_input.reshape((quant_input.size(0), -1, quant_input.size(-1))) + + # Compute pairwise distances + dist = torch.cdist(quant_input, self.embedding.weight[None, :].repeat((quant_input.size(0), 1, 1))) + + # Find index of nearest embedding + encoding_indices = torch.argmin(dist, dim=-1) # in form B, W*H + return encoding_indices + + """Returns the output from the encoding indices without calculating losses etc.""" + def output_from_indices(self, indices, output_shape_BHWC): + quant_out = torch.index_select(self.embedding.weight, 0, indices.view(-1)) + quant_out = quant_out.reshape(output_shape_BHWC).permute(0, 3, 1, 2) + return quant_out + + def forward(self, 
quant_input): + + B, C, H, W = quant_input.shape + + # Finds the encoding indices + encoding_indices = self.get_encoding_indices(quant_input) + + quant_input = quant_input.permute(0, 2, 3, 1) + quant_input = quant_input.reshape((quant_input.size(0), -1, quant_input.size(-1))) + + # Gets the output based on the encoding indices + quant_out = torch.index_select(self.embedding.weight, 0, encoding_indices.view(-1)) + + quant_input = quant_input.reshape((-1, quant_input.size(-1))) + + # Losses + commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) + codebook_loss = torch.mean((quant_out - quant_input.detach())**2) + loss = codebook_loss + self.beta*commitment_loss + + # Straight through gradient estimator for backprop + quant_out = quant_input + (quant_out - quant_input).detach() + + # Reshape quant_out + quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) + # Reshapes encoding indices to 'B, H, W' + encoding_indices = encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) + + return quant_out, loss, encoding_indices + +"""The Decoder Model used in VQ-VAE""" +class Decoder(nn.Module): + def __init__(self, embedding_dim) -> None: + super(Decoder, self).__init__() + + self.decoder = nn.Sequential( + nn.ConvTranspose2d(embedding_dim, 32, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(), + nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(1), + ) + + def forward(self, x): + out = self.decoder(x) + return out + + + +# ---------------------------------------------------------------------------- +# PixelCNN model +# ---------------------------------------------------------------------------- + +"""Autoregressive PixelCNN model""" +class PixelCNN(nn.Module): + + def __init__(self, in_channels, hidden_channels, num_embeddings): + super(PixelCNN, self).__init__() + + self.epochs_trained = 0 + + # Equal to the number of embeddings in the VQVAE + self.num_embeddings = num_embeddings + # Initial convolutions skipping the center pixel + self.conv_vstack = VerticalConv(in_channels, hidden_channels, mask_center=True) + self.conv_hstack = HorizontalConv(in_channels, hidden_channels, mask_center=True) + # Convolution block of PixelCNN. Uses dilation instead of downscaling + self.conv_layers = nn.ModuleList([ + GatedMaskedConv(hidden_channels), + GatedMaskedConv(hidden_channels, dilation=2), + GatedMaskedConv(hidden_channels), + GatedMaskedConv(hidden_channels, dilation=4), + GatedMaskedConv(hidden_channels), + GatedMaskedConv(hidden_channels, dilation=2), + GatedMaskedConv(hidden_channels) + ]) + # Output classification convolution (1x1) + # The output channels should be in_channels*number of embeddings to learn continuous space and calc. 
CrossEntropyLoss + self.conv_out = nn.Conv2d(hidden_channels, in_channels*self.num_embeddings, kernel_size=1, padding=0) + + + def forward(self, x): + # Scale input from 0 to 255 to -1 to 1 + x = (x.float() / 255.0) * 2 - 1 + + # Initial convolutions + v_stack = self.conv_vstack(x) + h_stack = self.conv_hstack(x) + # Gated Convolutions + for layer in self.conv_layers: + v_stack, h_stack = layer(v_stack, h_stack) + # 1x1 classification convolution + # Apply ELU (exponential activation function) before 1x1 convolution for non-linearity on residual connection + out = self.conv_out(F.elu(h_stack)) + + # Output dimensions: [Batch, Classes, Channels, Height, Width] (classes = num_embeddings) + out = out.reshape(out.shape[0], self.num_embeddings, out.shape[1]//256, out.shape[2], out.shape[3]) + return out + + """Indices shape should be in form B C H W + Pixels to fill should be marked with -1""" + @torch.no_grad() + def sample(self, ind_shape, ind): + # Generation loop (iterating through pixels across channels) + for h in range(ind_shape[2]): # Heights + for w in range(ind_shape[3]): # Widths + for c in range(ind_shape[1]): # Channels + # Skip if not to be filled (-1) + if (ind[:,c,h,w] != -1).all().item(): + continue + # Only have to input upper half of ind (rest are masked anyway) + pred = self.forward(ind[:,:,:h+1,:]) + probs = F.softmax(pred[:,:,c,h,w], dim=-1) + ind[:,c,h,w] = torch.multinomial(probs, num_samples=1).squeeze(dim=-1) + return ind + + +"""A general Masked convolution, with a the mask as a parameter.""" +class MaskedConvolution(nn.Module): + + def __init__(self, in_channels, out_channels, mask, dilation=1): + + super(MaskedConvolution, self).__init__() + kernel_size = (mask.shape[0], mask.shape[1]) + padding = ([dilation*(kernel_size[i] - 1) // 2 for i in range(2)]) + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation) + + # Mask as buffer (must be moved with devices) + self.register_buffer('mask', mask[None,None]) + + def forward(self, x): + self.conv.weight.data *= self.mask # Set all following weights to 0 (make sure it is in GPU) + return self.conv(x) + + +class VerticalConv(MaskedConvolution): + # Masks all pixels below + def __init__(self, in_channels, out_channels, kernel_size=3, mask_center=False, dilation=1): + mask = torch.ones(kernel_size, kernel_size) + mask[kernel_size//2+1:,:] = 0 + # For the first convolution, mask center row + if mask_center: + mask[kernel_size//2,:] = 0 + + super().__init__(in_channels, out_channels, mask, dilation=dilation) + +class HorizontalConv(MaskedConvolution): + + def __init__(self, in_channels, out_channels, kernel_size=3, mask_center=False, dilation=1): + # Mask out all pixels on the left. 
(Note that kernel has a size of 1 + # in height because we only look at the pixel in the same row) + mask = torch.ones(1,kernel_size) + mask[0,kernel_size//2+1:] = 0 + + # For first convolution, mask center pixel + if mask_center: + mask[0,kernel_size//2] = 0 + + super().__init__(in_channels, out_channels, mask, dilation=dilation) + +"""Gated Convolutions Model""" +class GatedMaskedConv(nn.Module): + + def __init__(self, in_channels, dilation=1): + + super(GatedMaskedConv, self).__init__() + self.conv_vert = VerticalConv(in_channels, out_channels=2*in_channels, dilation=dilation) + self.conv_horiz = HorizontalConv(in_channels, out_channels=2*in_channels, dilation=dilation) + self.conv_vert_to_horiz = nn.Conv2d(2*in_channels, 2*in_channels, kernel_size=1, padding=0) + self.conv_horiz_1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0) + + def forward(self, v_stack, h_stack): + # Vertical stack (left) + v_stack_feat = self.conv_vert(v_stack) + v_val, v_gate = v_stack_feat.chunk(2, dim=1) + v_stack_out = torch.tanh(v_val) * torch.sigmoid(v_gate) + + # Horizontal stack (right) + h_stack_feat = self.conv_horiz(h_stack) + h_stack_feat = h_stack_feat + self.conv_vert_to_horiz(v_stack_feat) + h_val, h_gate = h_stack_feat.chunk(2, dim=1) + h_stack_feat = torch.tanh(h_val) * torch.sigmoid(h_gate) + h_stack_out = self.conv_horiz_1x1(h_stack_feat) + h_stack_out = h_stack_out + h_stack + + return v_stack_out, h_stack_out \ No newline at end of file diff --git a/predict b/predict new file mode 100644 index 000000000..47b9e0c96 --- /dev/null +++ b/predict @@ -0,0 +1,122 @@ +"""Example Predictions""" + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as transforms +import numpy as np +from tqdm.auto import tqdm +from PIL import Image +import torch.utils.data +from torchvision import datasets, transforms, utils +import matplotlib.pyplot as plt + +import modules +import dataset + +# Replace with preferred device and local path(s) +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print("Torch version ", torch.__version__) +path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" +vqvae_save_path = "C:/Users/61423/COMP3710/COMP3710 Report/Models/" +pixelCNN_save_path = "C:/Users/61423/COMP3710/COMP3710 Report/Models/" + +# Load the models +vqvae_model = torch.load(vqvae_save_path + "trained_vqvae.pth", map_location=device) +cnn_model = torch.load(pixelCNN_save_path + "PixelCNN model.pth", map_location=device) + +encoder = vqvae_model.__getattr__("encoder") +quantiser = vqvae_model.__getattr__("quantiser") +decoder = vqvae_model.__getattr__("decoder") + +# Update this parameter if using a newer model +embedding_dim = 32 + +# Load data +processed_data = dataset.DataPreparer(path, "keras_png_slices_train/", "keras_png_slices_validate/", "keras_png_slices_test/", batch_size=128) + +# VQVAE outputs +def show_reconstructed_imgs(num_shown=2): + input_imgs = processed_data.test_dataset[0:num_shown] + input_imgs = input_imgs.to(device, dtype=torch.float32) + with torch.no_grad(): + output_imgs, _, _, encoding_indices = vqvae_model(input_imgs) + + fig, ax = plt.subplots(num_shown, 3) + plt.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0) + ax[0, 0].set_title("Input Image") + ax[0, 1].set_title("CodeBook Indices") + ax[0, 2].set_title("Reconstructed Image") + for i in range(num_shown): + for j in range(3): + ax[i, j].axis('off') + ax[i, 0].imshow(input_imgs[i][0].cpu().numpy(), cmap='gray') + ax[i, 
1].imshow(encoding_indices[i].cpu().numpy(), cmap='gray') + ax[i, 2].imshow(output_imgs[i][0].cpu().numpy(), cmap='gray') + + plt.show() + +# PixelCNN Indices Generation +def show_generated_indices(shown_imgs=2): + + print(" > Showing Images") + + # Inputs + test_batch = processed_data.test_dataset[0:shown_imgs] + + encoder_output = encoder(test_batch.to(device)) + _, _, indices = quantiser(encoder_output) + + indices_shape = indices.cpu().numpy().shape + + #print("Indices shape is: ", indices.cpu().numpy().shape) + indices = indices.reshape((indices_shape[0], 1,indices_shape[1], indices_shape[2])) + #print("Indices shape is: ", indices.cpu().numpy().shape) + + # Masked Inputs (only top half shown to model) + masked_indices = 1*indices + masked_indices[:,:,16:,:] = -1 + + gen_indices = cnn_model.sample((shown_imgs, 1, 32, 32), ind=masked_indices*1) + + fig, ax = plt.subplots(shown_imgs, 3) + + for a in ax.flatten(): + a.axis('off') + + ax[0, 0].set_title("Real") + ax[0, 1].set_title("Masked") + ax[0, 2].set_title("Generated") + + for i in range(shown_imgs): + ax[i, 0].imshow(indices[i][0].long().cpu().numpy(), cmap='gray') + ax[i, 1].imshow(masked_indices[i][0].cpu().numpy(), cmap='gray') + ax[i, 2].imshow(gen_indices[i][0].cpu().numpy(), cmap='gray') + plt.show() + +def show_generated_output(): + # Inputs + test_batch = processed_data.test_dataset[0:1] + + encoder_output = encoder(test_batch.to(device)) + _, _, indices = quantiser(encoder_output) + + indices_shape = indices.cpu().numpy().shape + + indices = indices.reshape((indices_shape[0], 1,indices_shape[1], indices_shape[2])) + + # Masked Inputs (only top half shown to model) + masked_indices = 1*indices + masked_indices[:,:,:,:] = -1 + + + gen_indices = cnn_model.sample((1, 1, 32, 32), ind=masked_indices*1) + + # Change the last 32 using a new model with embedding_dim > 32 + plt.imshow(vqvae_model.img_from_indices(gen_indices[0], (1, 32, 32, embedding_dim))[0][0].cpu().numpy(), cmap='gray') + plt.show() + +show_reconstructed_imgs() +show_generated_indices() +show_generated_output() diff --git a/train.py b/train.py new file mode 100644 index 000000000..754c83437 --- /dev/null +++ b/train.py @@ -0,0 +1,308 @@ +"""Training of the VQVAE and PixelCNN""" + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as transforms +import numpy as np +from tqdm.auto import tqdm +from PIL import Image +import torch.utils.data +from torchvision import datasets, transforms, utils +import matplotlib.pyplot as plt + +import modules +import dataset + +# Replace with preferred device and local path(s) +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print("Torch version ", torch.__version__) +path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" +vqvae_save_path = "C:/Users/61423/COMP3710/COMP3710 Report/Models/" +pixelCNN_save_path = "C:/Users/61423/COMP3710/COMP3710 Report/Models/" + + +# Hyperparameters +batch_size = 128 +vqvae_num_epochs = 10 +vqvae_lr = 5e-4 +cnn_num_epochs = 10 +cnn_lr = 5e-4 + +# Data (If necessary, replace with the local names of the train, validate and test folders) +print("> Loading Data") +processed_data = dataset.DataPreparer(path, "keras_png_slices_train/", "keras_png_slices_validate/", "keras_png_slices_test/", batch_size) +train_dataloader = processed_data.train_dataloader +validate_dataloader = processed_data.validate_dataloader +test_dataloader = processed_data.test_dataloader + +# Models (if not already saved) +vqvae_model = 
modules.VQVAE(num_embeddings=256, embedding_dim=64).to(device) +cnn_model = modules.PixelCNN(in_channels=1, hidden_channels=128, num_embeddings=256).to(device) +# Models (if previously saved). +#vqvae_model = torch.load(vqvae_save_path + "trained_vqvae_n.pth", map_location=device) +#cnn_model = torch.load(pixelCNN_save_path + "PixelCNN model_n.pth", map_location=device) + +# Optimisers +vqvae_optimiser = torch.optim.Adam(vqvae_model.parameters(), vqvae_lr) +cnn_optimiser = torch.optim.Adam(cnn_model.parameters(), cnn_lr) + +# Initialise losses (for graphing) +vqvae_training_loss = [] +vqvae_validation_loss = [] +cnn_training_loss = [] +cnn_validation_loss = [] + + +# -------------------------------------------------------- +# VQVAE functions +# -------------------------------------------------------- + +def train_vqvae(): + + print("> Training VQVAE") + + for epoch_num, epoch in enumerate(range(vqvae_num_epochs)): + + # Train + vqvae_model.train() + for train_batch in train_dataloader: + images = train_batch.to(device, dtype=torch.float32) + + output, quant_loss, reconstruction_loss, _ = vqvae_model(images) + training_loss = quant_loss + reconstruction_loss # Can be adjusted if necessary + + vqvae_optimiser.zero_grad() # Reset gradients to zero + training_loss.backward() # Calculate grad + vqvae_optimiser.step() # Adjust weights + + with torch.no_grad(): + vqvae_training_loss.append((quant_loss.detach().cpu(), reconstruction_loss.detach().cpu(), training_loss.detach().cpu())) + + # Evaluate + vqvae_model.eval() + for validate_batch in (validate_dataloader): + images = validate_batch.to(device, dtype=torch.float32) + + with torch.no_grad(): + output, quant_loss, reconstruction_loss, _ = vqvae_model(images) + validation_loss = quant_loss + reconstruction_loss + + with torch.no_grad(): + vqvae_validation_loss.append((quant_loss.detach().cpu(), reconstruction_loss.detach().cpu(), validation_loss.detach().cpu())) + + vqvae_model.epochs_trained += 1 + print("Epoch {} of {}. Training Loss: {}, Validation Loss: {}".format(epoch_num+1, vqvae_num_epochs, training_loss, validation_loss)) + + if (epoch_num+1) % 20 == 0: + print("Saving VQVAE model") + torch.save(vqvae_model, vqvae_save_path + "trained_vqvae_{}.pth".format(vqvae_model.epochs_trained)) + + +def plot_vqvae_losses(show_individual_losses=False): + # Losses are in the order (Quant, Reconstruction, Total) + plt.title("VQVAE Losses") + plt.xlabel("Epoch") + plt.ylabel("Loss") + + if show_individual_losses == False: + plt.plot([loss[2] for loss in vqvae_training_loss], color='blue') + plt.plot([loss[2] for loss in vqvae_validation_loss], color='red') + plt.legend(["Training Loss", "Validation Loss"]) + else: + plt.plot([loss[0] for loss in vqvae_training_loss], color='blue', ls='--') + plt.plot([loss[0] for loss in vqvae_validation_loss], color='red', ls='--') + plt.plot([loss[1] for loss in vqvae_training_loss], color='blue') + plt.plot([loss[1] for loss in vqvae_validation_loss], color='red') + plt.legend(["Training Quantisation Loss", "Validation Quantisation Loss", "Training Reconstruction Loss", "Validation Reconstruction Loss"]) + plt.show() + + +"""Function to test the VQVAE. 
Input the number of samples to show""" +def test_vqvae(num_shown=0): + + print("> Testing VQVAE") + + # Calculate losses + vqvae_model.eval() + test_losses = [] + for test_batch in (test_dataloader): + images = test_batch.to(device, dtype=torch.float32) + with torch.no_grad(): + output, quant_loss, r_loss, _ = vqvae_model(images) + # For averaging loss + test_losses.append((quant_loss + r_loss).cpu()) + + print("Average loss during testing is: ", np.mean(np.array(test_losses))) + + # Show N Reconstructed Images. + if (num_shown != 0): + input_imgs = processed_data.test_dataset[0:num_shown] + input_imgs = input_imgs.to(device, dtype=torch.float32) + with torch.no_grad(): + output_imgs, _, _, encoding_indices = vqvae_model(input_imgs) + + fig, ax = plt.subplots(num_shown, 3) + plt.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0) + ax[0, 0].set_title("Input Image") + ax[0, 1].set_title("CodeBook Indices") + ax[0, 2].set_title("Reconstructed Image") + for i in range(num_shown): + for j in range(3): + ax[i, j].axis('off') + ax[i, 0].imshow(input_imgs[i][0].cpu().numpy(), cmap='gray') + ax[i, 1].imshow(encoding_indices[i].cpu().numpy(), cmap='gray') + ax[i, 2].imshow(output_imgs[i][0].cpu().numpy(), cmap='gray') + + plt.show() + + +# Uncomment any functions to call +train_vqvae() +plot_vqvae_losses() +test_vqvae(num_shown=3) + + +# -------------------------------------------------------- +# PixCNN functions +# -------------------------------------------------------- + +encoder = vqvae_model.__getattr__("encoder") +quantiser = vqvae_model.__getattr__("quantiser") +decoder = vqvae_model.__getattr__("decoder") + + +def train_pixcnn(): + + print("> Training PixelCNN") + + for epoch_num, epoch in enumerate(range(cnn_num_epochs)): + + cnn_model.train() + + for train_batch in train_dataloader: + + # Get the quantised outputs + with torch.no_grad(): + encoder_output = encoder(train_batch.to(device)) + _, _, indices = quantiser(encoder_output) + indices = indices.reshape(indices.size(0), 1, indices.size(1), indices.size(2)).to(device) + + output = cnn_model(indices) + + # Compute loss + nll = F.cross_entropy(output, indices, reduction='none') # Negative log-likelihood + bpd = nll.mean(dim=[1,2,3]) * np.log2(np.exp(1)) # Bits per dimension + training_loss = bpd.mean() + + cnn_optimiser.zero_grad() # Reset gradients to zero for back-prop (not cumulative) + training_loss.backward() # Calculate grad + cnn_optimiser.step() # Adjust weights + + with torch.no_grad(): + cnn_training_loss.append(training_loss.cpu()) + + cnn_model.eval() + + for validate_batch in validate_dataloader: + with torch.no_grad(): + # Get the quantised outputs + encoder_output = encoder(validate_batch.to(device)) + _, _, indices = quantiser(encoder_output) + indices = indices.reshape(indices.size(0), 1, indices.size(1), indices.size(2)).to(device) + output = cnn_model(indices) + + # Compute loss + nll = F.cross_entropy(output, indices, reduction='none') # Negative log-likelihood + bpd = nll.mean(dim=[1,2,3]) * np.log2(np.exp(1)) # Bits per dimension + validation_loss = bpd.mean() + + with torch.no_grad(): + cnn_validation_loss.append(validation_loss.cpu()) + + cnn_model.epochs_trained += 1 + print("Epoch {} of {}. 
Training Loss: {}, Validation Loss: {}".format(epoch_num+1, cnn_num_epochs, training_loss, validation_loss)) + + if (epoch_num+1) % 20 == 0: + print("Saving Pixel CNN") + torch.save(cnn_model, pixelCNN_save_path + "PixelCNN model_{}.pth".format(cnn_model.epochs_trained)) + + +def plot_cnn_loss(): + # Losses are in the order (Quant, Reconstruction, Total) + plt.title("PixelCNN Losses") + plt.xlabel("Epoch") + plt.ylabel("Loss") + plt.plot(cnn_training_loss, color='blue') + plt.plot(cnn_validation_loss, color='red') + plt.legend(["Training Loss", "Validation Loss"]) + plt.show() + + +def test_cnn(shown_imgs=0): + print("> Testing PixelCNN") + + test_loss = [] + + cnn_model.eval() + + with torch.no_grad(): + for test_batch in test_dataloader: + # Get the quantised outputs + encoder_output = encoder(test_batch.to(device)) + _, _, indices = quantiser(encoder_output) + indices = indices.reshape(indices.size(0), 1, indices.size(1), indices.size(2)).to(device) + output = cnn_model(indices) + + # Compute loss + nll = F.cross_entropy(output, indices, reduction='none') # Negative log-likelihood + bpd = nll.mean(dim=[1,2,3]) * np.log2(np.exp(1)) # Bits per dimension + validation_loss = bpd.mean() + test_loss.append(validation_loss.cpu()) + + print("Average loss during testing is: ", np.mean(np.array(test_loss))) + + if shown_imgs != 0: + print(" > Showing Images") + + # Inputs + test_batch = processed_data.test_dataset[0:shown_imgs] + + encoder_output = encoder(test_batch.to(device)) + _, _, indices = quantiser(encoder_output) + + indices_shape = indices.cpu().numpy().shape + + print("Indices shape is: ", indices.cpu().numpy().shape) + indices = indices.reshape((indices_shape[0], 1,indices_shape[1], indices_shape[2])) + print("Indices shape is: ", indices.cpu().numpy().shape) + + # Masked Inputs (only top quarter shown) + masked_indices = 1*indices + masked_indices[:,:,16:,:] = -1 + + gen_indices = cnn_model.sample((shown_imgs, 1, 32, 32), ind=masked_indices*1) + + fig, ax = plt.subplots(shown_imgs, 3) + + for a in ax.flatten(): + a.axis('off') + + ax[0, 0].set_title("Real") + ax[0, 1].set_title("Masked") + ax[0, 2].set_title("Generated") + + for i in range(shown_imgs): + ax[i, 0].imshow(indices[i][0].long().cpu().numpy(), cmap='gray') + ax[i, 1].imshow(masked_indices[i][0].cpu().numpy(), cmap='gray') + ax[i, 2].imshow(gen_indices[i][0].cpu().numpy(), cmap='gray') + plt.show() + + plt.imshow(vqvae_model.img_from_indices(gen_indices[0], (1, 32, 32, 32))[0][0].cpu().numpy(), cmap='gray') + plt.show() + +# Uncomment any functions to call +train_pixcnn() +plot_cnn_loss() +test_cnn(shown_imgs=3)