From 7a431c0c742b1cb0195df3d9e41884993e9ffb19 Mon Sep 17 00:00:00 2001 From: DruCallaghan Date: Tue, 10 Oct 2023 15:23:13 +1000 Subject: [PATCH 01/26] Implemented OASIS data loader --- VQ_VAE_46992925/VQ_VAE | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 VQ_VAE_46992925/VQ_VAE diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE new file mode 100644 index 000000000..5fda7257f --- /dev/null +++ b/VQ_VAE_46992925/VQ_VAE @@ -0,0 +1,42 @@ +''' +VQ-VAE +''' +import os +import torch +import torch.nn as nn +import torchvision.transforms as transforms +import numpy as np +from tqdm import tqdm +from PIL import Image +import torch.utils.data +from torchvision import datasets, transforms, utils +import matplotlib.pyplot as plt + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print("Torch version ", torch.__version__) + + +path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" +output_path = path+"data/mnist/output_torch_"+"GAN"+"/" + + + + +def load_data_from_folder(name): + data = [] + for filename in os.listdir(path+name): + + image_path = os.path.join(path+name, filename) + image = Image.open(image_path) + image = np.array(image) + data.append(image) + + return np.array(data) + +print("> Loading Test data") + +train_data = torch.from_numpy(load_data_from_folder("keras_png_slices_test/")).to(device) +train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True) + +print("> Test Data Finsihed Loading") +print("The shape of the data is: ", train_data.shape) \ No newline at end of file From 9c26e5cd0a31c5e0756474d2fa3631187a5f755c Mon Sep 17 00:00:00 2001 From: DruCallaghan Date: Tue, 10 Oct 2023 20:59:18 +1000 Subject: [PATCH 02/26] Implemented basic VQVAE model --- VQ_VAE_46992925/VQ_VAE | 84 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 79 insertions(+), 5 deletions(-) diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE index 5fda7257f..7bbf21318 100644 --- a/VQ_VAE_46992925/VQ_VAE +++ b/VQ_VAE_46992925/VQ_VAE @@ -1,5 +1,7 @@ ''' VQ-VAE +Model as implemented by +https://www.youtube.com/watch?v=1ZHzAOutcnw ''' import os import torch @@ -16,11 +18,10 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print("Torch version ", torch.__version__) -path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" -output_path = path+"data/mnist/output_torch_"+"GAN"+"/" - - +# ------------------------------------------------ +# Data Loader +path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" def load_data_from_folder(name): data = [] @@ -39,4 +40,77 @@ train_data = torch.from_numpy(load_data_from_folder("keras_png_slices_test/")).t train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True) print("> Test Data Finsihed Loading") -print("The shape of the data is: ", train_data.shape) \ No newline at end of file +print("The shape of the data is: ", train_data.shape) + + + +# ------------------------------------------------ +# Model + +class VQVAE(nn.Module): + def __init__(self, ): + super(VQVAE, self).__init__() + + self.encoder = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.Conv2d(16, 4, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(4), + nn.ReLU(), + ) + + self.pre_quant_conv = nn.Conv2d(4, 2, kernel_size=1) # TODO FC layer?? 
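# An illustrative aside, not part of the commit: a minimal sketch of the nearest-codebook
# lookup that forward() below performs with torch.cdist / argmin / index_select, shown on
# a toy codebook of 3 vectors of dimension 2 (all names here are illustrative only):
#
#   codebook = nn.Embedding(num_embeddings=3, embedding_dim=2)   # weight shape (3, 2)
#   z = torch.randn(5, 2)                                        # 5 flattened encoder vectors
#   dist = torch.cdist(z, codebook.weight)                       # (5, 3) pairwise L2 distances
#   idx = dist.argmin(dim=-1)                                    # nearest code index per vector
#   z_q = codebook(idx)                                          # quantised vectors, shape (5, 2)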
+ self.embedding = nn.Embedding(num_embeddings=3, embedding_dim=2) + self.post_quant_conv = nn.Conv2d(2, 4, kernel_size=1) + + # Commitment loss beta + self.beta = 0.2 + + self.decoder = nn.Sequential( + nn.ConvTranspose2d(4, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1), + nn.Tanh(), + ) + + def forward(self, x): + # B, C, H, W + encoded_output = self.encoder(x) + quant_input = self.pre_quant_conv(encoded_output) + + # Quantisation + B, C, H, W = quant_input.shape + quant_input = quant_input.permute(0, 2, 3, 1) + quant_input = quant_input.reshape((quant_input.size(0), -1, quant_input.size(-1))) + + # Compute pairwise distances + dist = torch.cdist(quant_input, self.embedding.weight[None, :].repeat((quant_input.size(0), 1, 1))) + + # Find index of nearest embedding + min_encoding_indices = torch.argmin(dist, dim=-1) + + # Select the embedding weights + quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1)) + + quant_input = quant_input.reshape((-1, quant_input.size(-1))) + + # Compute losses + commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) + codebook_loss = torch.mean((quant_out - quant_input.detach())**2) + total_losses = codebook_loss + self.beta*commitment_loss + + # Straight through gradient estimator + quant_out = quant_input + (quant_out - quant_input).detach() # Detach ~ ignored for back-prop + quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) + min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) + + + # Decoding + decoder_input = self.post_quant_conv(quant_out) + output = self.decoder(decoder_input) + return output, total_losses + +model = VQVAE() +print(model) \ No newline at end of file From 35e6c28dbc545dcf736f203f8e96f6135256e861 Mon Sep 17 00:00:00 2001 From: DruCallaghan Date: Wed, 11 Oct 2023 10:16:14 +1000 Subject: [PATCH 03/26] Added basic training prototype to test --- VQ_VAE_46992925/VQ_VAE | 62 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 55 insertions(+), 7 deletions(-) diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE index 7bbf21318..e88ca8c8a 100644 --- a/VQ_VAE_46992925/VQ_VAE +++ b/VQ_VAE_46992925/VQ_VAE @@ -6,6 +6,7 @@ https://www.youtube.com/watch?v=1ZHzAOutcnw import os import torch import torch.nn as nn +import torch.nn.functional as F import torchvision.transforms as transforms import numpy as np from tqdm import tqdm @@ -36,8 +37,14 @@ def load_data_from_folder(name): print("> Loading Test data") -train_data = torch.from_numpy(load_data_from_folder("keras_png_slices_test/")).to(device) -train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True) +########## TODO THIS IS NOT SEGMENTATION!! SPLIT TRAINING INTO TRAIN AND TEST!!! 
+ +train_data = torch.from_numpy(load_data_from_folder("keras_png_slices_train/")).to(device) +test_data = torch.from_numpy(load_data_from_folder("keras_png_slices_test/")).to(device) + +train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True) +test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=True) + print("> Test Data Finsihed Loading") print("The shape of the data is: ", train_data.shape) @@ -99,18 +106,59 @@ class VQVAE(nn.Module): # Compute losses commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) codebook_loss = torch.mean((quant_out - quant_input.detach())**2) - total_losses = codebook_loss + self.beta*commitment_loss + # Straight through gradient estimator quant_out = quant_input + (quant_out - quant_input).detach() # Detach ~ ignored for back-prop quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) - - + # Decoding decoder_input = self.post_quant_conv(quant_out) output = self.decoder(decoder_input) + + # Reconstruction Loss, and find the total loss + reconstruction_loss = F.mse_loss(x, output) + total_losses = reconstruction_loss + codebook_loss + self.beta*commitment_loss + return output, total_losses -model = VQVAE() -print(model) \ No newline at end of file + +# ------------------------------------------------ +# Training + +########################## TODO THERE IS NO RECONSTRUCTION LOSS!! + +# Hyperparams +learning_rate = 1.e-3 +num_epochs = 3 + + +model = VQVAE().to(device) +print(model) + +optimiser = torch.optim.Adam(model.parameters(), learning_rate) + +for epoch_num, epoch in enumerate(tqdm(range(num_epochs))): + model.train() + for train_batch in tqdm(train_dataloader): + images, labels = train_batch + images = images.to(device) + + output, total_losses = model(images) + + optimiser.zero_grad() # Reset gradients to zero for back-prop + total_losses.backward() # Calculate grad + optimiser.step() + + # Evaluate + model.eval() + + for test_batch in tqdm(test_dataloader): + images, labels = test_batch + images = images.to(device) + + with torch.no_grad(): + output, total_losses = model(images) + + print("Epoch {} of {}. Total Loss: {}".format(epoch_num, num_epochs, total_losses)) \ No newline at end of file From f0e1c6375be02725acf45b86f82adad26bf432bc Mon Sep 17 00:00:00 2001 From: Dru Callaghan Date: Wed, 11 Oct 2023 15:48:11 +1000 Subject: [PATCH 04/26] Implemented first trainable model (not converging) --- VQ_VAE_46992925/VQ_VAE | 98 ++++++++++++++++++++++++++++++++---------- 1 file changed, 76 insertions(+), 22 deletions(-) diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE index e88ca8c8a..75ab2199c 100644 --- a/VQ_VAE_46992925/VQ_VAE +++ b/VQ_VAE_46992925/VQ_VAE @@ -8,7 +8,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as transforms -import numpy as np +import numpy as np from tqdm import tqdm from PIL import Image import torch.utils.data @@ -23,32 +23,41 @@ print("Torch version ", torch.__version__) # Data Loader path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" +path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/keras_png_slices_data/keras_png_slices_data/" + def load_data_from_folder(name): data = [] - for filename in os.listdir(path+name): + + for filename in tqdm(os.listdir(path+name)): # tqdm adds loading bar! 
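# Shape bookkeeping for one slice in the loop below (a sketch; the 256x256 size is an
# assumption about the OASIS PNGs):
#
#   np.array(Image.open(p).convert('L'))            -> (256, 256)         uint8, H x W
#   img.reshape((1, img.shape[0], img.shape[1]))    -> (1, 256, 256)      C x H x W
#   np.array(list_of_slices)                        -> (N, 1, 256, 256)   B x C x H x W
#
# which is the B, C, H, W layout the single-channel Conv2d encoder expects.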
image_path = os.path.join(path+name, filename) - image = Image.open(image_path) - image = np.array(image) + image = Image.open(image_path).convert('L') # Convert to grayscale (single channel) + image = np.array(image) + + # ENSURE THERE IS CHANNEL AS WELL: + image = image.reshape((1, image.shape[0], image.shape[1])) + data.append(image) - - return np.array(data) -print("> Loading Test data") + return np.array(data) -########## TODO THIS IS NOT SEGMENTATION!! SPLIT TRAINING INTO TRAIN AND TEST!!! +####### TODO THIS IS NOT SEGMENTATION!! SPLIT TRAINING INTO TRAIN AND TEST!!! +# C, H, W +print("> Loading Training data") train_data = torch.from_numpy(load_data_from_folder("keras_png_slices_train/")).to(device) +print("> Loading Test data") test_data = torch.from_numpy(load_data_from_folder("keras_png_slices_test/")).to(device) +# B, C, H, W train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True) test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=True) +# TODO NORMALISE!!!!! -print("> Test Data Finsihed Loading") -print("The shape of the data is: ", train_data.shape) - +print("> Data Loading Finished") +print("The shape of the (training) data is: ", train_data.shape) # ------------------------------------------------ @@ -62,13 +71,16 @@ class VQVAE(nn.Module): nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(16), nn.ReLU(), + nn.Conv2d(16, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), nn.Conv2d(16, 4, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(4), nn.ReLU(), ) self.pre_quant_conv = nn.Conv2d(4, 2, kernel_size=1) # TODO FC layer?? - self.embedding = nn.Embedding(num_embeddings=3, embedding_dim=2) + self.embedding = nn.Embedding(num_embeddings=64, embedding_dim=2) self.post_quant_conv = nn.Conv2d(2, 4, kernel_size=1) # Commitment loss beta @@ -78,8 +90,11 @@ class VQVAE(nn.Module): nn.ConvTranspose2d(4, 16, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(16), nn.ReLU(), + nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1), - nn.Tanh(), + nn.Sigmoid(), ) def forward(self, x): @@ -104,7 +119,7 @@ class VQVAE(nn.Module): quant_input = quant_input.reshape((-1, quant_input.size(-1))) # Compute losses - commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) + commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) # TODO change to MSE codebook_loss = torch.mean((quant_out - quant_input.detach())**2) @@ -120,6 +135,8 @@ class VQVAE(nn.Module): # Reconstruction Loss, and find the total loss reconstruction_loss = F.mse_loss(x, output) total_losses = reconstruction_loss + codebook_loss + self.beta*commitment_loss + print("The reconstruction loss makes up {}% of the total loss ({}/{})" + .format(reconstruction_loss*100//(total_losses), int(reconstruction_loss), int(total_losses))) return output, total_losses @@ -129,9 +146,10 @@ class VQVAE(nn.Module): ########################## TODO THERE IS NO RECONSTRUCTION LOSS!! 
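# For reference, the objective assembled in forward() above is the standard VQ-VAE loss
#
#   L = ||x - x_hat||^2 + ||sg[z_e(x)] - e||^2 + beta * ||z_e(x) - sg[e]||^2
#
# where sg[.] is stop-gradient (implemented with .detach()); the three terms are the
# reconstruction, codebook and commitment losses combined into total_losses.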
+ # Hyperparams learning_rate = 1.e-3 -num_epochs = 3 +num_epochs = 5 model = VQVAE().to(device) @@ -142,23 +160,59 @@ optimiser = torch.optim.Adam(model.parameters(), learning_rate) for epoch_num, epoch in enumerate(tqdm(range(num_epochs))): model.train() for train_batch in tqdm(train_dataloader): - images, labels = train_batch - images = images.to(device) + images = train_batch + images = images.to(device, dtype=torch.float32) output, total_losses = model(images) - optimiser.zero_grad() # Reset gradients to zero for back-prop + optimiser.zero_grad() # Reset gradients to zero for back-prop (not cumulative) total_losses.backward() # Calculate grad - optimiser.step() + optimiser.step() # Adjust weights # Evaluate model.eval() for test_batch in tqdm(test_dataloader): - images, labels = test_batch - images = images.to(device) + images = test_batch + + images = images.to(device, dtype=torch.float32) # (Set as float to ensure weights input are the same type) with torch.no_grad(): output, total_losses = model(images) + - print("Epoch {} of {}. Total Loss: {}".format(epoch_num, num_epochs, total_losses)) \ No newline at end of file + print("Epoch {} of {}. Total Loss: {}".format(epoch_num, num_epochs, total_losses)) + + + + +# ------------------------------------------------- +# Visualise + +# C, H, W +input_img = test_data[0] +input_img = input_img.reshape(1, 1, input_img.size(-2), input_img.size(-1)) +input_img = input_img.to(device, dtype=torch.float32) + +# DEBUGGING Print the input image shape and show it. +print("Shape of the input img is: ", input_img.shape) +#plt.imshow(input_img[0][0].cpu().numpy()) +#plt.gray() +#plt.show() + + +with torch.no_grad(): # Ensure no gradient calculation + output, _ = model(input_img) # Forward pass through the model + +print("Shape of the output img is: ", output.shape) + +# Display input and output images +plt.figure(figsize=(10, 5)) +plt.subplot(1, 2, 1) +plt.title("Input Image") +plt.imshow(input_img[0][0].cpu().numpy(), cmap='gray') # Assuming single-channel input + +plt.subplot(1, 2, 2) +plt.title("Model Output") +plt.imshow(output[0][0].cpu().numpy(), cmap='gray') # Assuming single-channel output +plt.show() \ No newline at end of file From 55c3f7c7727bce139fcf84c514ca5bf1495e6e71 Mon Sep 17 00:00:00 2001 From: Dru Callaghan Date: Wed, 11 Oct 2023 16:29:33 +1000 Subject: [PATCH 05/26] Non-working model. Changing data Normalisation and losses --- VQ_VAE_46992925/VQ_VAE | 43 ++++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE index 75ab2199c..efe55e5e5 100644 --- a/VQ_VAE_46992925/VQ_VAE +++ b/VQ_VAE_46992925/VQ_VAE @@ -28,6 +28,8 @@ path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/ker def load_data_from_folder(name): data = [] + + i = 0 for filename in tqdm(os.listdir(path+name)): # tqdm adds loading bar! @@ -40,15 +42,28 @@ def load_data_from_folder(name): data.append(image) - return np.array(data) + if i == 50: + return np.array(data) + i += 1 -####### TODO THIS IS NOT SEGMENTATION!! SPLIT TRAINING INTO TRAIN AND TEST!!! 
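# Note on the Normalise step introduced below: transforms.ToTensor() rescales uint8 images
# to [0, 1], so the statistics handed to transforms.Normalize should be on that scale as
# well. A sketch, assuming train_data is a uint8 array in [0, 255]:
#
#   mean = train_data.mean() / 255.0
#   std  = train_data.std()  / 255.0
#   transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
#
# which would put the values on the same scale as the hard-coded 0.13242 / 0.18826 used in
# later revisions.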
+ return np.array(data) -# C, H, W +# C, H, W (Numpy array) print("> Loading Training data") -train_data = torch.from_numpy(load_data_from_folder("keras_png_slices_train/")).to(device) +train_data = load_data_from_folder("keras_png_slices_train/") print("> Loading Test data") -test_data = torch.from_numpy(load_data_from_folder("keras_png_slices_test/")).to(device) +test_data = load_data_from_folder("keras_png_slices_test/") + +print("The shape of the (training) data is: ", train_data.shape) + +# Normalise +mean = np.mean(train_data) +std = np.std(train_data) +data_transform = transforms.Compose([transforms.ToTensor(), # Convert to PyTorch tensor + transforms.Normalize(mean, std)]) + +train_data = (data_transform(train_data)).to(device) +test_data = (data_transform(test_data)).to(device) # B, C, H, W train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True) @@ -57,7 +72,6 @@ test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle # TODO NORMALISE!!!!! print("> Data Loading Finished") -print("The shape of the (training) data is: ", train_data.shape) # ------------------------------------------------ @@ -80,11 +94,12 @@ class VQVAE(nn.Module): ) self.pre_quant_conv = nn.Conv2d(4, 2, kernel_size=1) # TODO FC layer?? - self.embedding = nn.Embedding(num_embeddings=64, embedding_dim=2) + self.embedding = nn.Embedding(num_embeddings=256, embedding_dim=2) self.post_quant_conv = nn.Conv2d(2, 4, kernel_size=1) # Commitment loss beta self.beta = 0.2 + self.alpha = 0.1 self.decoder = nn.Sequential( nn.ConvTranspose2d(4, 16, kernel_size=4, stride=2, padding=1), @@ -134,9 +149,11 @@ class VQVAE(nn.Module): # Reconstruction Loss, and find the total loss reconstruction_loss = F.mse_loss(x, output) - total_losses = reconstruction_loss + codebook_loss + self.beta*commitment_loss - print("The reconstruction loss makes up {}% of the total loss ({}/{})" - .format(reconstruction_loss*100//(total_losses), int(reconstruction_loss), int(total_losses))) + total_losses = self.alpha*reconstruction_loss + codebook_loss + self.beta*commitment_loss + + # TODO ensure the losses are balanced + #print("The reconstruction loss makes up {}% of the total loss ({}/{})" + # .format(reconstruction_loss*100//(total_losses), int(reconstruction_loss), int(total_losses))) return output, total_losses @@ -148,8 +165,8 @@ class VQVAE(nn.Module): # Hyperparams -learning_rate = 1.e-3 -num_epochs = 5 +learning_rate = 1.e-4 +num_epochs = 2 model = VQVAE().to(device) @@ -184,8 +201,6 @@ for epoch_num, epoch in enumerate(tqdm(range(num_epochs))): print("Epoch {} of {}. 
Total Loss: {}".format(epoch_num, num_epochs, total_losses)) - - # ------------------------------------------------- # Visualise From 58d57d0503c2a1ff45c1cb37b195ac07e134c9fd Mon Sep 17 00:00:00 2001 From: DruCallaghan Date: Wed, 11 Oct 2023 21:42:56 +1000 Subject: [PATCH 06/26] Resolved dimension issue when normalising data --- VQ_VAE_46992925/VQ_VAE | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE index efe55e5e5..3cc35f39a 100644 --- a/VQ_VAE_46992925/VQ_VAE +++ b/VQ_VAE_46992925/VQ_VAE @@ -23,12 +23,11 @@ print("Torch version ", torch.__version__) # Data Loader path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" -path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/keras_png_slices_data/keras_png_slices_data/" +#path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/keras_png_slices_data/keras_png_slices_data/" def load_data_from_folder(name): data = [] - i = 0 for filename in tqdm(os.listdir(path+name)): # tqdm adds loading bar! @@ -36,13 +35,12 @@ def load_data_from_folder(name): image_path = os.path.join(path+name, filename) image = Image.open(image_path).convert('L') # Convert to grayscale (single channel) image = np.array(image) - - # ENSURE THERE IS CHANNEL AS WELL: - image = image.reshape((1, image.shape[0], image.shape[1])) - + + # Ensure the image has the shape (C, H, W) + image = np.expand_dims(image, axis=2) data.append(image) - if i == 50: + if i == 10: return np.array(data) i += 1 @@ -62,8 +60,8 @@ std = np.std(train_data) data_transform = transforms.Compose([transforms.ToTensor(), # Convert to PyTorch tensor transforms.Normalize(mean, std)]) -train_data = (data_transform(train_data)).to(device) -test_data = (data_transform(test_data)).to(device) +train_data = [data_transform(item) for item in train_data] +test_data = [data_transform(item) for item in test_data] # B, C, H, W train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True) @@ -73,7 +71,6 @@ test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle print("> Data Loading Finished") - # ------------------------------------------------ # Model @@ -166,7 +163,7 @@ class VQVAE(nn.Module): # Hyperparams learning_rate = 1.e-4 -num_epochs = 2 +num_epochs = 1 model = VQVAE().to(device) From 2d73cb9e4cd4601fb4dc56799d788a2ff0e5dd7c Mon Sep 17 00:00:00 2001 From: DruCallaghan Date: Wed, 11 Oct 2023 23:36:28 +1000 Subject: [PATCH 07/26] Working Normalisation, Model produces visible outline --- VQ_VAE_46992925/VQ_VAE | 58 ++++++++++++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE index 3cc35f39a..a4c11f5de 100644 --- a/VQ_VAE_46992925/VQ_VAE +++ b/VQ_VAE_46992925/VQ_VAE @@ -25,10 +25,9 @@ print("Torch version ", torch.__version__) path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" #path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/keras_png_slices_data/keras_png_slices_data/" - def load_data_from_folder(name): data = [] - i = 0 + #i = 0 for filename in tqdm(os.listdir(path+name)): # tqdm adds loading bar! 
@@ -37,37 +36,48 @@ def load_data_from_folder(name): image = np.array(image) # Ensure the image has the shape (C, H, W) - image = np.expand_dims(image, axis=2) + image = np.expand_dims(image, axis=0) + #image = image.reshape((1, image.shape[0], image.shape[1])) + data.append(image) - if i == 10: - return np.array(data) - i += 1 + #if i == 100: + # return np.array(data) + #i += 1 return np.array(data) -# C, H, W (Numpy array) +# Loading +# B, C, H, W (Numpy array) print("> Loading Training data") -train_data = load_data_from_folder("keras_png_slices_train/") +train_data = (load_data_from_folder("keras_png_slices_train/")) print("> Loading Test data") -test_data = load_data_from_folder("keras_png_slices_test/") +test_data = (load_data_from_folder("keras_png_slices_test/")) print("The shape of the (training) data is: ", train_data.shape) +print("The shape of the (testing) data is: ", test_data.shape) -# Normalise +# Transforms and tensor mean = np.mean(train_data) std = np.std(train_data) -data_transform = transforms.Compose([transforms.ToTensor(), # Convert to PyTorch tensor - transforms.Normalize(mean, std)]) -train_data = [data_transform(item) for item in train_data] -test_data = [data_transform(item) for item in test_data] +transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) + +train_data = torch.stack([transform(item) for item in train_data]).permute(0, 2, 3, 1) +test_data = torch.stack([transform(item) for item in test_data]).permute(0, 2, 3, 1) + +print("The shape of the (training) data is: ", train_data.shape) +print("The shape of the (testing) data is: ", test_data.shape) +# DataLoaders # B, C, H, W train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True) test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=True) -# TODO NORMALISE!!!!! +plt.imshow(train_data[0][0]) +plt.title("First Training image") +plt.gray() +plt.show() print("> Data Loading Finished") @@ -96,7 +106,7 @@ class VQVAE(nn.Module): # Commitment loss beta self.beta = 0.2 - self.alpha = 0.1 + self.alpha = 1.0 self.decoder = nn.Sequential( nn.ConvTranspose2d(4, 16, kernel_size=4, stride=2, padding=1), @@ -160,11 +170,11 @@ class VQVAE(nn.Module): ########################## TODO THERE IS NO RECONSTRUCTION LOSS!! +losses = [] # for visualisation # Hyperparams -learning_rate = 1.e-4 -num_epochs = 1 - +learning_rate = 1.e-3 +num_epochs = 2 model = VQVAE().to(device) print(model) @@ -196,6 +206,8 @@ for epoch_num, epoch in enumerate(tqdm(range(num_epochs))): print("Epoch {} of {}. 
Total Loss: {}".format(epoch_num, num_epochs, total_losses)) + + losses.append(total_losses) # To graph losses # ------------------------------------------------- @@ -203,6 +215,8 @@ for epoch_num, epoch in enumerate(tqdm(range(num_epochs))): # C, H, W input_img = test_data[0] + +# Reshape to B, C, H, W for the model input_img = input_img.reshape(1, 1, input_img.size(-2), input_img.size(-1)) input_img = input_img.to(device, dtype=torch.float32) @@ -227,4 +241,10 @@ plt.imshow(input_img[0][0].cpu().numpy(), cmap='gray') # Assuming single-channe plt.subplot(1, 2, 2) plt.title("Model Output") plt.imshow(output[0][0].cpu().numpy(), cmap='gray') # Assuming single-channel output +plt.show() + +plt.plot(losses) +plt.title("Losses") +plt.xlabel("Num Epochs") +plt.ylabel("Loss") plt.show() \ No newline at end of file From e24c0a70938fd79d32816eaa246490af0cb98ff0 Mon Sep 17 00:00:00 2001 From: Dru Callaghan Date: Fri, 13 Oct 2023 13:43:34 +1000 Subject: [PATCH 08/26] First successful data normalisation. Working VQ-VAE --- VQ_VAE_46992925/DataPrep.py | 241 ++++++++++++++++++++++++++++++++++++ 1 file changed, 241 insertions(+) create mode 100644 VQ_VAE_46992925/DataPrep.py diff --git a/VQ_VAE_46992925/DataPrep.py b/VQ_VAE_46992925/DataPrep.py new file mode 100644 index 000000000..b86038cde --- /dev/null +++ b/VQ_VAE_46992925/DataPrep.py @@ -0,0 +1,241 @@ +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as transforms +import numpy as np +from tqdm import tqdm +from PIL import Image +from torch.utils.data import Dataset, DataLoader +from torchvision import datasets, transforms, utils +import matplotlib.pyplot as plt + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print("Torch version ", torch.__version__) + + +# ------------------------------------------------ +# Data Loader + +#path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" +path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/keras_png_slices_data/keras_png_slices_data/" + + +class ImageDataset(Dataset): + def __init__(self, root_dir, transform=None): + self.root_dir = root_dir + self.image_files = [f for f in os.listdir(root_dir) if f.lower().endswith('.png')] + self.transform = transform + + def __len__(self): + return len(self.image_files) + + def __getitem__(self, idx): + image_path = os.path.join(self.root_dir, self.image_files[idx]) + image = Image.open(image_path).convert('L') # Convert to grayscale + if self.transform: + image = self.transform(image) + return image + +print("Loading data") + +transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(0.13242, 0.18826) + ]) + +train_data_dir = "keras_png_slices_train/" +test_data_dir = "keras_png_slices_test/" + + +train_dataset = ImageDataset(path+train_data_dir, transform=transform) +test_dataset = ImageDataset(path+test_data_dir, transform=transform) + +# DataLoaders +# B, C, H, W +train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True) +test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True) + +# Debugging +first_batch = 0 +for batch in train_dataloader: + first_batch = batch + break +print("Shape of first batch is: ", first_batch.shape) +print("First batch - Mean: {} Std: {}".format(torch.mean(first_batch), torch.std(first_batch))) +plt.imshow(first_batch[0][0]) +plt.title("First Training image (Normalised)") +plt.gray() +plt.show() + + +print("> Data Loading Finished") + + +# 
------------------------------------------------ +# Model + +class VQVAE(nn.Module): + def __init__(self, ): + super(VQVAE, self).__init__() + + self.encoder = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.Conv2d(16, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.Conv2d(16, 4, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(4), + nn.ReLU(), + ) + + self.pre_quant_conv = nn.Conv2d(4, 2, kernel_size=1) # TODO FC layer?? + self.embedding = nn.Embedding(num_embeddings=256, embedding_dim=2) + self.post_quant_conv = nn.Conv2d(2, 4, kernel_size=1) + + # Commitment loss beta + self.beta = 0.2 + self.alpha = 1.0 + + self.decoder = nn.Sequential( + nn.ConvTranspose2d(4, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1), + nn.Sigmoid(), + ) + + def forward(self, x): + # B, C, H, W + encoded_output = self.encoder(x) + quant_input = self.pre_quant_conv(encoded_output) + + # Quantisation + B, C, H, W = quant_input.shape + quant_input = quant_input.permute(0, 2, 3, 1) + quant_input = quant_input.reshape((quant_input.size(0), -1, quant_input.size(-1))) + + # Compute pairwise distances + dist = torch.cdist(quant_input, self.embedding.weight[None, :].repeat((quant_input.size(0), 1, 1))) + + # Find index of nearest embedding + min_encoding_indices = torch.argmin(dist, dim=-1) + + # Select the embedding weights + quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1)) + + quant_input = quant_input.reshape((-1, quant_input.size(-1))) + + # Compute losses + commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) # TODO change to MSE + codebook_loss = torch.mean((quant_out - quant_input.detach())**2) + + + # Straight through gradient estimator + quant_out = quant_input + (quant_out - quant_input).detach() # Detach ~ ignored for back-prop + quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) + min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) + + # Decoding + decoder_input = self.post_quant_conv(quant_out) + output = self.decoder(decoder_input) + + # Reconstruction Loss, and find the total loss + reconstruction_loss = F.mse_loss(x, output) + total_losses = self.alpha*reconstruction_loss + codebook_loss + self.beta*commitment_loss + + # TODO ensure the losses are balanced + #print("The reconstruction loss makes up {}% of the total loss ({}/{})" + # .format(reconstruction_loss*100//(total_losses), int(reconstruction_loss), int(total_losses))) + + return output, total_losses + + +# ------------------------------------------------ +# Training + +########################## TODO THERE IS NO RECONSTRUCTION LOSS!! 
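# Aside on the loss bookkeeping below: losses.append(total_losses) stores a device tensor
# for only the final test batch of each epoch. A sketch of a plain-float, per-epoch average
# instead (names illustrative):
#
#   epoch_loss, n = 0.0, 0
#   for batch in test_dataloader:
#       ...                                # forward pass producing total_losses
#       epoch_loss += total_losses.item()
#       n += 1
#   losses.append(epoch_loss / n)
#
# which also avoids keeping CUDA tensors around just for plotting.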
+ +losses = [] # for visualisation + +# Hyperparams +learning_rate = 1.e-3 +num_epochs = 1 + +model = VQVAE().to(device) +print(model) + +optimiser = torch.optim.Adam(model.parameters(), learning_rate) + +for epoch_num, epoch in enumerate(tqdm(range(num_epochs))): + model.train() + for train_batch in tqdm(train_dataloader): + images = train_batch + images = images.to(device, dtype=torch.float32) + + output, total_losses = model(images) + + optimiser.zero_grad() # Reset gradients to zero for back-prop (not cumulative) + total_losses.backward() # Calculate grad + optimiser.step() # Adjust weights + + # Evaluate + model.eval() + + for test_batch in tqdm(test_dataloader): + images = test_batch + + images = images.to(device, dtype=torch.float32) # (Set as float to ensure weights input are the same type) + + with torch.no_grad(): + output, total_losses = model(images) + + + print("Epoch {} of {}. Total Loss: {}".format(epoch_num, num_epochs, total_losses)) + + losses.append(total_losses) # To graph losses + + +# ------------------------------------------------- +# Visualise + +# C, H, W +input_img = test_dataset[0][0] + +# Reshape to B, C, H, W for the model +input_img = input_img.reshape(1, 1, input_img.size(-2), input_img.size(-1)) +input_img = input_img.to(device, dtype=torch.float32) + +# DEBUGGING Print the input image shape and show it. +print("Shape of the input img is: ", input_img.shape) +#plt.imshow(input_img[0][0].cpu().numpy()) +#plt.gray() +#plt.show() + + +with torch.no_grad(): # Ensure no gradient calculation + output, _ = model(input_img) # Forward pass through the model + +print("Shape of the output img is: ", output.shape) + +# Display input and output images +plt.figure(figsize=(10, 5)) +plt.subplot(1, 2, 1) +plt.title("Input Image") +plt.imshow(input_img[0][0].cpu().numpy(), cmap='gray') # Assuming single-channel input + +plt.subplot(1, 2, 2) +plt.title("Model Output") +plt.imshow(output[0][0].cpu().numpy(), cmap='gray') # Assuming single-channel output +plt.show() + +plt.plot(losses) +plt.title("Losses") +plt.xlabel("Num Epochs") +plt.ylabel("Loss") +plt.show() \ No newline at end of file From 0d5f544987a876bcac12a1078564af0265f12d22 Mon Sep 17 00:00:00 2001 From: Dru Callaghan Date: Fri, 13 Oct 2023 14:19:05 +1000 Subject: [PATCH 09/26] Fixed losses graph. Data is now loaded at beginning. 
--- VQ_VAE_46992925/DataPrep.py | 2 +- VQ_VAE_46992925/VQ_VAE | 34 ++++++++++++++++++---------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/VQ_VAE_46992925/DataPrep.py b/VQ_VAE_46992925/DataPrep.py index b86038cde..566bea6a8 100644 --- a/VQ_VAE_46992925/DataPrep.py +++ b/VQ_VAE_46992925/DataPrep.py @@ -165,7 +165,7 @@ def forward(self, x): # Hyperparams learning_rate = 1.e-3 -num_epochs = 1 +num_epochs = 7 model = VQVAE().to(device) print(model) diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE index a4c11f5de..48ae6c850 100644 --- a/VQ_VAE_46992925/VQ_VAE +++ b/VQ_VAE_46992925/VQ_VAE @@ -22,28 +22,28 @@ print("Torch version ", torch.__version__) # ------------------------------------------------ # Data Loader -path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" -#path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/keras_png_slices_data/keras_png_slices_data/" +#path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" +path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/keras_png_slices_data/keras_png_slices_data/" def load_data_from_folder(name): data = [] - #i = 0 + i = 0 - for filename in tqdm(os.listdir(path+name)): # tqdm adds loading bar! + for filename in tqdm([f for f in os.listdir(path+name) if f.lower().endswith('.png')]): # tqdm adds loading bar! image_path = os.path.join(path+name, filename) image = Image.open(image_path).convert('L') # Convert to grayscale (single channel) image = np.array(image) - # Ensure the image has the shape (C, H, W) - image = np.expand_dims(image, axis=0) - #image = image.reshape((1, image.shape[0], image.shape[1])) - + # Add channel + # C, H, W + image = np.expand_dims(image, axis=0) + data.append(image) - #if i == 100: - # return np.array(data) - #i += 1 + if i == 100: + return np.array(data) + i += 1 return np.array(data) @@ -61,7 +61,10 @@ print("The shape of the (testing) data is: ", test_data.shape) mean = np.mean(train_data) std = np.std(train_data) -transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) +transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=0.13242, std=0.18826) +]) train_data = torch.stack([transform(item) for item in train_data]).permute(0, 2, 3, 1) test_data = torch.stack([transform(item) for item in test_data]).permute(0, 2, 3, 1) @@ -75,7 +78,7 @@ train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=128, shuff test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=True) plt.imshow(train_data[0][0]) -plt.title("First Training image") +plt.title("First Training image (Normalised)") plt.gray() plt.show() @@ -174,7 +177,7 @@ losses = [] # for visualisation # Hyperparams learning_rate = 1.e-3 -num_epochs = 2 +num_epochs = 15 model = VQVAE().to(device) print(model) @@ -207,7 +210,7 @@ for epoch_num, epoch in enumerate(tqdm(range(num_epochs))): print("Epoch {} of {}. 
Total Loss: {}".format(epoch_num, num_epochs, total_losses)) - losses.append(total_losses) # To graph losses + losses.append(total_losses.cpu().numpy()[0]) # To graph losses # ------------------------------------------------- @@ -215,7 +218,6 @@ for epoch_num, epoch in enumerate(tqdm(range(num_epochs))): # C, H, W input_img = test_data[0] - # Reshape to B, C, H, W for the model input_img = input_img.reshape(1, 1, input_img.size(-2), input_img.size(-1)) input_img = input_img.to(device, dtype=torch.float32) From 0114173dae2508e248b380a5eacef3915d128571 Mon Sep 17 00:00:00 2001 From: Dru Callaghan Date: Fri, 13 Oct 2023 14:34:28 +1000 Subject: [PATCH 10/26] Replaced Sigmoid with Batchnorm, greatly improved output --- VQ_VAE_46992925/VQ_VAE | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE index 48ae6c850..6887d1847 100644 --- a/VQ_VAE_46992925/VQ_VAE +++ b/VQ_VAE_46992925/VQ_VAE @@ -41,9 +41,9 @@ def load_data_from_folder(name): data.append(image) - if i == 100: - return np.array(data) - i += 1 + #if i == 100: + # return np.array(data) + #i += 1 return np.array(data) @@ -98,28 +98,28 @@ class VQVAE(nn.Module): nn.Conv2d(16, 16, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(16), nn.ReLU(), - nn.Conv2d(16, 4, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(4), + nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(32), nn.ReLU(), ) - self.pre_quant_conv = nn.Conv2d(4, 2, kernel_size=1) # TODO FC layer?? - self.embedding = nn.Embedding(num_embeddings=256, embedding_dim=2) - self.post_quant_conv = nn.Conv2d(2, 4, kernel_size=1) + self.pre_quant_conv = nn.Conv2d(32, 32, kernel_size=1) # TODO FC layer?? + self.embedding = nn.Embedding(num_embeddings=256, embedding_dim=32) + self.post_quant_conv = nn.Conv2d(32, 32, kernel_size=1) # Commitment loss beta self.beta = 0.2 self.alpha = 1.0 self.decoder = nn.Sequential( - nn.ConvTranspose2d(4, 16, kernel_size=4, stride=2, padding=1), + nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(16), nn.ReLU(), nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(16), nn.ReLU(), nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1), - nn.Sigmoid(), + nn.BatchNorm2d(1), ) def forward(self, x): @@ -210,7 +210,7 @@ for epoch_num, epoch in enumerate(tqdm(range(num_epochs))): print("Epoch {} of {}. 
Total Loss: {}".format(epoch_num, num_epochs, total_losses)) - losses.append(total_losses.cpu().numpy()[0]) # To graph losses + losses.append(total_losses.cpu()) # To graph losses (TODO still in tensors) # ------------------------------------------------- From 16667dadd4c95ac4fd04b84ce3b15af3f5cccf6c Mon Sep 17 00:00:00 2001 From: Dru Callaghan Date: Fri, 13 Oct 2023 21:02:05 +1000 Subject: [PATCH 11/26] Non-functional PixelCNN implementation --- VQ_VAE_46992925/VQ_VAE | 231 ++++++++++++++++++++++++++++++++-------- VQ_VAE_46992925/test.py | 50 +++++++++ 2 files changed, 239 insertions(+), 42 deletions(-) create mode 100644 VQ_VAE_46992925/test.py diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE index 6887d1847..07efe9be6 100644 --- a/VQ_VAE_46992925/VQ_VAE +++ b/VQ_VAE_46992925/VQ_VAE @@ -41,9 +41,9 @@ def load_data_from_folder(name): data.append(image) - #if i == 100: - # return np.array(data) - #i += 1 + if i == 50: + return np.array(data) + i += 1 return np.array(data) @@ -77,10 +77,10 @@ print("The shape of the (testing) data is: ", test_data.shape) train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True) test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=True) -plt.imshow(train_data[0][0]) -plt.title("First Training image (Normalised)") -plt.gray() -plt.show() +#plt.imshow(train_data[0][0]) +#plt.title("First Training image (Normalised)") +#plt.gray() +#plt.show() print("> Data Loading Finished") @@ -91,6 +91,8 @@ class VQVAE(nn.Module): def __init__(self, ): super(VQVAE, self).__init__() + self.encoding_indices = None # Save the encoding indices to be accessed for pixel CNN + self.encoder = nn.Sequential( nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(16), @@ -136,10 +138,10 @@ class VQVAE(nn.Module): dist = torch.cdist(quant_input, self.embedding.weight[None, :].repeat((quant_input.size(0), 1, 1))) # Find index of nearest embedding - min_encoding_indices = torch.argmin(dist, dim=-1) - + encoding_indices = torch.argmin(dist, dim=-1) # in form B, W*H + # Select the embedding weights - quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1)) + quant_out = torch.index_select(self.embedding.weight, 0, encoding_indices.view(-1)) quant_input = quant_input.reshape((-1, quant_input.size(-1))) @@ -151,8 +153,11 @@ class VQVAE(nn.Module): # Straight through gradient estimator quant_out = quant_input + (quant_out - quant_input).detach() # Detach ~ ignored for back-prop quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) - min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) - + + # Reshape encoding indices to 'B, H, W' + encoding_indices = encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) + self.encoding_indices = encoding_indices + # Decoding decoder_input = self.post_quant_conv(quant_out) output = self.decoder(decoder_input) @@ -167,6 +172,9 @@ class VQVAE(nn.Module): return output, total_losses + def get_indices(self): + return self.encoding_indices + # ------------------------------------------------ # Training @@ -177,14 +185,15 @@ losses = [] # for visualisation # Hyperparams learning_rate = 1.e-3 -num_epochs = 15 +num_epochs = 2 model = VQVAE().to(device) print(model) optimiser = torch.optim.Adam(model.parameters(), learning_rate) -for epoch_num, epoch in enumerate(tqdm(range(num_epochs))): + +for epoch_num, epoch in enumerate(range(num_epochs)): model.train() for train_batch in 
tqdm(train_dataloader): images = train_batch @@ -196,10 +205,11 @@ for epoch_num, epoch in enumerate(tqdm(range(num_epochs))): total_losses.backward() # Calculate grad optimiser.step() # Adjust weights + # Evaluate model.eval() - for test_batch in tqdm(test_dataloader): + for test_batch in (test_dataloader): images = test_batch images = images.to(device, dtype=torch.float32) # (Set as float to ensure weights input are the same type) @@ -216,37 +226,174 @@ for epoch_num, epoch in enumerate(tqdm(range(num_epochs))): # ------------------------------------------------- # Visualise -# C, H, W -input_img = test_data[0] -# Reshape to B, C, H, W for the model -input_img = input_img.reshape(1, 1, input_img.size(-2), input_img.size(-1)) -input_img = input_img.to(device, dtype=torch.float32) -# DEBUGGING Print the input image shape and show it. -print("Shape of the input img is: ", input_img.shape) -#plt.imshow(input_img[0][0].cpu().numpy()) -#plt.gray() -#plt.show() +def plot_results(num_images): + + input_imgs = test_data[0:num_images] + input_imgs = input_imgs.to(device, dtype=torch.float32) + + # DEBUGGING + print("Shape of the input img is: ", input_imgs.shape) + + with torch.no_grad(): # Ensure no gradient calculation + output_imgs, _ = model(input_imgs) # Forward pass through the model + + #Debugging + print("Shape of the output img is: ", output_imgs.shape) + + + fig, ax = plt.subplots(num_images, 2) + plt.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0) + + ax[0, 0].set_title("Inputs") + ax[0, 1].set_title("Reconstructions") + for i in range(num_images): + for j in range(2): + ax[i, j].axis('off') + ax[i, 0].imshow(input_imgs[i][0].cpu().numpy(), cmap='gray') + ax[i, 1].imshow(output_imgs[i][0].cpu().numpy(), cmap='gray') + + plt.show() + + plt.plot(losses) + plt.title("Losses") + plt.xlabel("Num Epochs") + plt.ylabel("Loss") + plt.show() + +plot_results(2) + + + + +# ------------- Pixel CNN + + +class MaskedConv2d(nn.Conv2d): + + def __init__(self, num_channels, kernel_size): -with torch.no_grad(): # Ensure no gradient calculation - output, _ = model(input_img) # Forward pass through the model + super(MaskedConv2d, self).__init__(num_channels, num_channels, kernel_size=kernel_size, padding=(kernel_size//2)) -print("Shape of the output img is: ", output.shape) + #self.register_buffer('mask', torch.zeros_like(self.weight)) -# Display input and output images -plt.figure(figsize=(10, 5)) -plt.subplot(1, 2, 1) -plt.title("Input Image") -plt.imshow(input_img[0][0].cpu().numpy(), cmap='gray') # Assuming single-channel input + k = self.kernel_size[0] + + self.weight.data[:, :, (k//2+1):, :].zero_() + self.weight.data[:, :, k//2, k//2:].zero_() + + + def forward(self, x): + + k = self.kernel_size[0] + # Type 'A' mask + self.weight.data[:, :, (k//2+1):, :].zero_() + self.weight.data[:, :, k//2, k//2:].zero_() + + out = super(MaskedConv2d, self).forward(x) + return out + + +class PixelCNN(nn.Module): + def __init__(self): + super(PixelCNN, self).__init__() + num_channels = 24 + self.num_channels = num_channels + + self.embedding = nn.Embedding(256, 24) + + # Convolutions + self.masked_conv = MaskedConv2d(num_channels=num_channels, kernel_size=6) + + self.block1 = nn.Sequential( + nn.Conv2d(num_channels, 16, 3, padding=1), + nn.BatchNorm2d(16), + nn.ReLU() + ) + self.block2 = nn.Sequential( + nn.Conv2d(16, 32, 3, padding=1), + nn.BatchNorm2d(32), + nn.ReLU() + ) + self.final = nn.Sequential(nn.Conv2d(32, 1, 4, padding=1), nn.BatchNorm2d(1)) + + def forward(self, x): + + out = 
self.embedding(x).permute(0, 3, 1, 2) + out = self.masked_conv(out) + out = self.block1(out) + out = self.block2(out) + out = self.final(out) + + return out + +# ---------------- Training Pixel CNN -------------------- # + + + + +losses = [] # for visualisation +train_indices = [] # For training pix_cnn +test_indices = [] + +with torch.no_grad(): + for train_batch in (train_dataloader): + # Get the indices from the model + images = (train_batch).to(device) + _, _ = model(images) + train_indices.append(model.encoding_indices) + for test_batch in (test_dataloader): + images = test_batch.to(device) + _, _ = model(images) + test_indices.append(model.encoding_indices) + +train_indices = torch.stack(train_indices).permute(1, 0, 2, 3).squeeze(1) +test_indices = torch.stack(test_indices).permute(1, 0, 2, 3).squeeze(1) + +print("Train indices is shape: ", train_indices.shape) + +train_dataloader = torch.utils.data.DataLoader(train_indices, batch_size=128, shuffle=True) +test_dataloader = torch.utils.data.DataLoader(train_indices, batch_size=128, shuffle=True) + +# Hyperparams +learning_rate = 1e-3 +num_epochs = 1 + + +pix_cnn = PixelCNN().to(device) +print(pix_cnn) +optimiser = torch.optim.Adam(pix_cnn.parameters(), learning_rate) + + + +for epoch_num, epoch in enumerate(range(num_epochs)): + + pix_cnn.train() + for train_batch in tqdm(train_dataloader): + # Get the indices from the model + train_indices = train_batch.to(device, dtype=torch.long) + # Put indices into cnn + print("Shape of input indices is: ", train_indices.shape) + output = pix_cnn(train_indices) + + print(train_indices[0]) + print(output[0]) + + print("Shape of output is: ", output.size) + loss = F.cross_entropy(output, train_indices) + + optimiser.zero_grad() # Reset gradients to zero for back-prop (not cumulative) + loss.backward() # Calculate grad + optimiser.step() # Adjust weights + + + pix_cnn.eval() + for test_batch in (test_dataloader): + test_indices = test_batch.to(device, dtype=torch.long) + with torch.no_grad(): + output = pix_cnn(test_indices) + loss = F.cross_entropy(output, test_indices) -plt.subplot(1, 2, 2) -plt.title("Model Output") -plt.imshow(output[0][0].cpu().numpy(), cmap='gray') # Assuming single-channel output -plt.show() -plt.plot(losses) -plt.title("Losses") -plt.xlabel("Num Epochs") -plt.ylabel("Loss") -plt.show() \ No newline at end of file + print("Epoch {} of {}. 
Total Loss: {}".format(epoch_num, num_epochs, loss)) diff --git a/VQ_VAE_46992925/test.py b/VQ_VAE_46992925/test.py new file mode 100644 index 000000000..1c5328d4f --- /dev/null +++ b/VQ_VAE_46992925/test.py @@ -0,0 +1,50 @@ +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as transforms +import numpy as np +from tqdm import tqdm +from PIL import Image +import torch.utils.data +from torchvision import datasets, transforms, utils +import matplotlib.pyplot as plt + +conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3) + +weights = conv.weight.data +biases = conv.bias.data + +print("Weights: ", weights) +print("Biases: ", biases.shape) + + +class MaskedConv2d(nn.Conv2d): + + + def __init__(self, num_channels, kernel_size): + + super(MaskedConv2d, self).__init__(num_channels, num_channels, kernel_size=kernel_size, padding=(kernel_size//2)) + + #self.register_buffer('mask', torch.zeros_like(self.weight)) + + k = self.kernel_size[0] + + self.weight.data[:, :, (k//2+1):, :].zero_() + self.weight.data[:, :, k//2, k//2:].zero_() + + + def forward(self, x): + k = self.kernel_size[0] + # Type 'A' mask + self.weight.data[:, :, (k//2+1):, :].zero_() + self.weight.data[:, :, k//2, k//2:].zero_() + + out = super(MaskedConv2d, self).forward(x) + return out + + + +masked_conv = MaskedConv2d(num_channels=1, kernel_size=5) + +print("Masked weights: ", masked_conv.weight.data) \ No newline at end of file From f2b0d1b62d8248ad91c46124fc235032500c58bb Mon Sep 17 00:00:00 2001 From: DruCallaghan Date: Sun, 15 Oct 2023 21:56:49 +1000 Subject: [PATCH 12/26] Created separate classes for components, removed pre and post quant conv --- VQ_VAE_46992925/VQ_VAE | 256 +++++++++++++---------------------------- 1 file changed, 78 insertions(+), 178 deletions(-) diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE index 07efe9be6..e4a7cf299 100644 --- a/VQ_VAE_46992925/VQ_VAE +++ b/VQ_VAE_46992925/VQ_VAE @@ -22,8 +22,8 @@ print("Torch version ", torch.__version__) # ------------------------------------------------ # Data Loader -#path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" -path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/keras_png_slices_data/keras_png_slices_data/" +path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" +#path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/keras_png_slices_data/keras_png_slices_data/" def load_data_from_folder(name): data = [] @@ -87,12 +87,12 @@ print("> Data Loading Finished") # ------------------------------------------------ # Model -class VQVAE(nn.Module): + + +class Encoder(nn.Module): def __init__(self, ): - super(VQVAE, self).__init__() + super(Encoder, self).__init__() - self.encoding_indices = None # Save the encoding indices to be accessed for pixel CNN - self.encoder = nn.Sequential( nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(16), @@ -105,76 +105,106 @@ class VQVAE(nn.Module): nn.ReLU(), ) - self.pre_quant_conv = nn.Conv2d(32, 32, kernel_size=1) # TODO FC layer?? 
- self.embedding = nn.Embedding(num_embeddings=256, embedding_dim=32) - self.post_quant_conv = nn.Conv2d(32, 32, kernel_size=1) + def forward(self, x): + out = self.encoder(x) + return out + + +class Quantiser(nn.Module): + def __init__(self, num_embeddings, embedding_dim) -> None: + super(Quantiser, self).__init__() - # Commitment loss beta + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim self.beta = 0.2 - self.alpha = 1.0 - - self.decoder = nn.Sequential( - nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(1), - ) - - def forward(self, x): - # B, C, H, W - encoded_output = self.encoder(x) - quant_input = self.pre_quant_conv(encoded_output) - # Quantisation - B, C, H, W = quant_input.shape + self.embedding = self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + + + def get_encoding_indices(self, quant_input): + # Flatten quant_input = quant_input.permute(0, 2, 3, 1) quant_input = quant_input.reshape((quant_input.size(0), -1, quant_input.size(-1))) - + # Compute pairwise distances dist = torch.cdist(quant_input, self.embedding.weight[None, :].repeat((quant_input.size(0), 1, 1))) # Find index of nearest embedding encoding_indices = torch.argmin(dist, dim=-1) # in form B, W*H + return encoding_indices + + + def forward(self, quant_input): + + B, C, H, W = quant_input.shape + + # Get the encoding indices + encoding_indices = self.get_encoding_indices(quant_input) # Select the embedding weights quant_out = torch.index_select(self.embedding.weight, 0, encoding_indices.view(-1)) + quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) + + print(quant_out.shape, quant_input.shape) - quant_input = quant_input.reshape((-1, quant_input.size(-1))) # Compute losses commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) # TODO change to MSE codebook_loss = torch.mean((quant_out - quant_input.detach())**2) - + loss = codebook_loss + self.beta*commitment_loss # Straight through gradient estimator quant_out = quant_input + (quant_out - quant_input).detach() # Detach ~ ignored for back-prop - quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) # Reshape encoding indices to 'B, H, W' + # TODO CURRENTLY MEANS NOTHING encoding_indices = encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) - self.encoding_indices = encoding_indices + + return quant_out, loss - # Decoding - decoder_input = self.post_quant_conv(quant_out) - output = self.decoder(decoder_input) + +class Decoder(nn.Module): + def __init__(self, ) -> None: + super(Decoder, self).__init__() + + self.decoder = nn.Sequential( + nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(1), + ) + + def forward(self, x): + out = self.decoder(x) + return out + + + + +class VQVAE(nn.Module): + def __init__(self, ): + super(VQVAE, self).__init__() + + self.encoder = Encoder() + self.quantiser = Quantiser(num_embeddings=265, embedding_dim=32) + self.decoder = Decoder() + + + def forward(self, x): + # B, C, H, W + quant_input = self.encoder(x) + quant_out, quant_loss = self.quantiser(quant_input) 
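# The straight-through estimator inside the Quantiser above can be read as
#
#   z_q = z + (z_q - z).detach()
#
# which evaluates to z_q in the forward pass but, because the detached term carries no
# gradient, copies the decoder's gradient straight onto the encoder output z in the
# backward pass.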
+ output = self.decoder(quant_out) # Reconstruction Loss, and find the total loss reconstruction_loss = F.mse_loss(x, output) - total_losses = self.alpha*reconstruction_loss + codebook_loss + self.beta*commitment_loss - - # TODO ensure the losses are balanced - #print("The reconstruction loss makes up {}% of the total loss ({}/{})" - # .format(reconstruction_loss*100//(total_losses), int(reconstruction_loss), int(total_losses))) + total_loss = quant_loss + reconstruction_loss - return output, total_losses - - def get_indices(self): - return self.encoding_indices - + return output, total_loss # ------------------------------------------------ # Training @@ -266,134 +296,4 @@ plot_results(2) - -# ------------- Pixel CNN - - -class MaskedConv2d(nn.Conv2d): - - def __init__(self, num_channels, kernel_size): - - super(MaskedConv2d, self).__init__(num_channels, num_channels, kernel_size=kernel_size, padding=(kernel_size//2)) - - #self.register_buffer('mask', torch.zeros_like(self.weight)) - - k = self.kernel_size[0] - - self.weight.data[:, :, (k//2+1):, :].zero_() - self.weight.data[:, :, k//2, k//2:].zero_() - - - def forward(self, x): - - k = self.kernel_size[0] - # Type 'A' mask - self.weight.data[:, :, (k//2+1):, :].zero_() - self.weight.data[:, :, k//2, k//2:].zero_() - - out = super(MaskedConv2d, self).forward(x) - return out - - -class PixelCNN(nn.Module): - def __init__(self): - super(PixelCNN, self).__init__() - num_channels = 24 - self.num_channels = num_channels - - self.embedding = nn.Embedding(256, 24) - - # Convolutions - self.masked_conv = MaskedConv2d(num_channels=num_channels, kernel_size=6) - - self.block1 = nn.Sequential( - nn.Conv2d(num_channels, 16, 3, padding=1), - nn.BatchNorm2d(16), - nn.ReLU() - ) - self.block2 = nn.Sequential( - nn.Conv2d(16, 32, 3, padding=1), - nn.BatchNorm2d(32), - nn.ReLU() - ) - self.final = nn.Sequential(nn.Conv2d(32, 1, 4, padding=1), nn.BatchNorm2d(1)) - - def forward(self, x): - - out = self.embedding(x).permute(0, 3, 1, 2) - out = self.masked_conv(out) - out = self.block1(out) - out = self.block2(out) - out = self.final(out) - - return out - -# ---------------- Training Pixel CNN -------------------- # - - - - -losses = [] # for visualisation -train_indices = [] # For training pix_cnn -test_indices = [] - -with torch.no_grad(): - for train_batch in (train_dataloader): - # Get the indices from the model - images = (train_batch).to(device) - _, _ = model(images) - train_indices.append(model.encoding_indices) - for test_batch in (test_dataloader): - images = test_batch.to(device) - _, _ = model(images) - test_indices.append(model.encoding_indices) - -train_indices = torch.stack(train_indices).permute(1, 0, 2, 3).squeeze(1) -test_indices = torch.stack(test_indices).permute(1, 0, 2, 3).squeeze(1) - -print("Train indices is shape: ", train_indices.shape) - -train_dataloader = torch.utils.data.DataLoader(train_indices, batch_size=128, shuffle=True) -test_dataloader = torch.utils.data.DataLoader(train_indices, batch_size=128, shuffle=True) - -# Hyperparams -learning_rate = 1e-3 -num_epochs = 1 - - -pix_cnn = PixelCNN().to(device) -print(pix_cnn) -optimiser = torch.optim.Adam(pix_cnn.parameters(), learning_rate) - - - -for epoch_num, epoch in enumerate(range(num_epochs)): - - pix_cnn.train() - for train_batch in tqdm(train_dataloader): - # Get the indices from the model - train_indices = train_batch.to(device, dtype=torch.long) - # Put indices into cnn - print("Shape of input indices is: ", train_indices.shape) - output = pix_cnn(train_indices) - - 
print(train_indices[0]) - print(output[0]) - - print("Shape of output is: ", output.size) - loss = F.cross_entropy(output, train_indices) - - optimiser.zero_grad() # Reset gradients to zero for back-prop (not cumulative) - loss.backward() # Calculate grad - optimiser.step() # Adjust weights - - - pix_cnn.eval() - for test_batch in (test_dataloader): - test_indices = test_batch.to(device, dtype=torch.long) - with torch.no_grad(): - output = pix_cnn(test_indices) - loss = F.cross_entropy(output, test_indices) - - - print("Epoch {} of {}. Total Loss: {}".format(epoch_num, num_epochs, loss)) +# ------------- Pixel CNN \ No newline at end of file From 6ef526bb26e0239665c46497eb44c1ab0b7e3dfa Mon Sep 17 00:00:00 2001 From: DruCallaghan Date: Sun, 15 Oct 2023 22:17:06 +1000 Subject: [PATCH 13/26] Added visualisation of codebook indices --- VQ_VAE_46992925/VQ_VAE | 110 ++++++++++++++++++++++++++++++++++------- 1 file changed, 91 insertions(+), 19 deletions(-) diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE index e4a7cf299..9d05c8daa 100644 --- a/VQ_VAE_46992925/VQ_VAE +++ b/VQ_VAE_46992925/VQ_VAE @@ -160,7 +160,7 @@ class Quantiser(nn.Module): # TODO CURRENTLY MEANS NOTHING encoding_indices = encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) - return quant_out, loss + return quant_out, loss, encoding_indices class Decoder(nn.Module): @@ -186,38 +186,39 @@ class Decoder(nn.Module): class VQVAE(nn.Module): - def __init__(self, ): + def __init__(self, num_embeddings, embedding_dim): super(VQVAE, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.encoder = Encoder() - self.quantiser = Quantiser(num_embeddings=265, embedding_dim=32) + self.quantiser = Quantiser(num_embeddings=num_embeddings, embedding_dim=embedding_dim) self.decoder = Decoder() def forward(self, x): # B, C, H, W quant_input = self.encoder(x) - quant_out, quant_loss = self.quantiser(quant_input) + quant_out, quant_loss, encoding_indices = self.quantiser(quant_input) output = self.decoder(quant_out) # Reconstruction Loss, and find the total loss reconstruction_loss = F.mse_loss(x, output) total_loss = quant_loss + reconstruction_loss - return output, total_loss + return output, total_loss, encoding_indices # ------------------------------------------------ # Training -########################## TODO THERE IS NO RECONSTRUCTION LOSS!! - losses = [] # for visualisation # Hyperparams learning_rate = 1.e-3 -num_epochs = 2 +num_epochs = 5 -model = VQVAE().to(device) +model = VQVAE(num_embeddings=265, embedding_dim=32).to(device) print(model) optimiser = torch.optim.Adam(model.parameters(), learning_rate) @@ -229,7 +230,7 @@ for epoch_num, epoch in enumerate(range(num_epochs)): images = train_batch images = images.to(device, dtype=torch.float32) - output, total_losses = model(images) + output, total_losses, _ = model(images) optimiser.zero_grad() # Reset gradients to zero for back-prop (not cumulative) total_losses.backward() # Calculate grad @@ -245,14 +246,13 @@ for epoch_num, epoch in enumerate(range(num_epochs)): images = images.to(device, dtype=torch.float32) # (Set as float to ensure weights input are the same type) with torch.no_grad(): - output, total_losses = model(images) + output, total_losses, _ = model(images) print("Epoch {} of {}. 
Total Loss: {}".format(epoch_num, num_epochs, total_losses)) losses.append(total_losses.cpu()) # To graph losses (TODO still in tensors) - # ------------------------------------------------- # Visualise @@ -266,23 +266,27 @@ def plot_results(num_images): print("Shape of the input img is: ", input_imgs.shape) with torch.no_grad(): # Ensure no gradient calculation - output_imgs, _ = model(input_imgs) # Forward pass through the model - + output_imgs, _, encoding_indices = model(input_imgs) + + #Debugging print("Shape of the output img is: ", output_imgs.shape) + print("Enc indices shape is: ", encoding_indices.shape) - fig, ax = plt.subplots(num_images, 2) + fig, ax = plt.subplots(num_images, 3) plt.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0) ax[0, 0].set_title("Inputs") - ax[0, 1].set_title("Reconstructions") + ax[0, 1].set_title("CodeBook Indices") + ax[0, 2].set_title("Reconstruction") for i in range(num_images): - for j in range(2): + for j in range(3): ax[i, j].axis('off') ax[i, 0].imshow(input_imgs[i][0].cpu().numpy(), cmap='gray') - ax[i, 1].imshow(output_imgs[i][0].cpu().numpy(), cmap='gray') + ax[i, 1].imshow(encoding_indices[i].cpu().numpy()) + ax[i, 2].imshow(output_imgs[i][0].cpu().numpy(), cmap='gray') plt.show() @@ -296,4 +300,72 @@ plot_results(2) -# ------------- Pixel CNN \ No newline at end of file +# ------------- Pixel CNN +# Define the PixelConvLayer class in PyTorch +class PixelConvLayer(nn.Module): + def __init__(self, mask_type, in_channels, out_channels, kernel_size, stride, padding): + super(PixelConvLayer, self).__init__() + self.mask_type = mask_type + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) + self.build_mask() + + def build_mask(self): + kernel_shape = self.conv.weight.shape + self.mask = torch.zeros(kernel_shape) + self.mask[:, :, :kernel_shape[2] // 2, :] = 1.0 + self.mask[:, :, kernel_shape[2] // 2, :kernel_shape[3] // 2] = 1.0 + if self.mask_type == "B": + self.mask[:, :, kernel_shape[2] // 2, kernel_shape[3] // 2] = 1.0 + + def forward(self, inputs): + with torch.no_grad(): + self.conv.weight.data *= self.mask + return self.conv(inputs) + +# Define the ResidualBlock class in PyTorch +class ResidualBlock(nn.Module): + def __init__(self, in_channels, filters): + super(ResidualBlock, self).__init__() + self.conv1 = nn.Conv2d(in_channels, filters, kernel_size=1) + self.pixel_conv = PixelConvLayer(mask_type="B", in_channels=filters, out_channels=filters // 2, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(filters // 2, filters, kernel_size=1) + + def forward(self, inputs): + x = self.conv1(inputs) + x = self.pixel_conv(x) + x = self.conv2(x) + return inputs + x + +# Define your PyTorch model +class PixelCNN(nn.Module): + def __init__(self, pixelcnn_input_shape, num_residual_blocks, num_pixelcnn_layers, num_embeddings): + super(PixelCNN, self).__init__() + self.pixel_conv_initial = PixelConvLayer(mask_type="A", in_channels=num_embeddings, out_channels=128, kernel_size=7, stride=1, padding=3) + self.residual_blocks = nn.ModuleList([ResidualBlock(in_channels=128, filters=128) for _ in range(num_residual_blocks)]) + self.pixel_conv_layers = nn.ModuleList([PixelConvLayer(mask_type="B", in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0) for _ in range(num_pixelcnn_layers)]) + self.out_conv = nn.Conv2d(in_channels=128, out_channels=num_embeddings, kernel_size=1, stride=1, padding=0) + + def forward(self, inputs): + x = self.pixel_conv_initial(inputs) + for block in 
self.residual_blocks: + x = block(x) + for layer in self.pixel_conv_layers: + x = layer(x) + out = self.out_conv(x) + return out + +# Create an instance of the PixelCNN model +batch_size = 128 +num_embeddings = 256 +height = 256 +width = 256 + +pixelcnn_input_shape = (batch_size, num_embeddings, height, width) # Define your input shape +num_residual_blocks = 5 # Define the number of residual blocks +num_pixelcnn_layers = 5 # Define the number of PixelConvLayers +num_embeddings = model.num_embeddings # Define the number of embeddings +batch_size = 1 + +cnn_model = PixelCNN(pixelcnn_input_shape, num_residual_blocks, num_pixelcnn_layers, num_embeddings) +print(cnn_model) + From b0ad3d5f8db9bfe93851bad86c6f7f13bf7ee0d5 Mon Sep 17 00:00:00 2001 From: DruCallaghan Date: Mon, 16 Oct 2023 15:25:27 +1000 Subject: [PATCH 14/26] Working PixelCNN. Generates 1 table of floats using MSE loss --- VQ_VAE_46992925/VQ_VAE | 112 ++++++++++++++++++++++++++++++++++------- 1 file changed, 95 insertions(+), 17 deletions(-) diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE index 9d05c8daa..b75e20937 100644 --- a/VQ_VAE_46992925/VQ_VAE +++ b/VQ_VAE_46992925/VQ_VAE @@ -41,7 +41,7 @@ def load_data_from_folder(name): data.append(image) - if i == 50: + if i == 25: return np.array(data) i += 1 @@ -132,7 +132,7 @@ class Quantiser(nn.Module): # Find index of nearest embedding encoding_indices = torch.argmin(dist, dim=-1) # in form B, W*H return encoding_indices - + def forward(self, quant_input): @@ -216,9 +216,14 @@ losses = [] # for visualisation # Hyperparams learning_rate = 1.e-3 -num_epochs = 5 +num_epochs = 6 + +num_embeddings = 256 +embedding_dim = 32 + + -model = VQVAE(num_embeddings=265, embedding_dim=32).to(device) +model = VQVAE(num_embeddings=num_embeddings, embedding_dim=embedding_dim).to(device) print(model) optimiser = torch.optim.Adam(model.parameters(), learning_rate) @@ -319,8 +324,11 @@ class PixelConvLayer(nn.Module): def forward(self, inputs): with torch.no_grad(): + #print("Weights: ", self.conv.weight.data.cpu().numpy().shape, "Mask: ", self.mask.cpu().numpy().shape, + # "inputs: ", inputs.cpu().numpy().shape) self.conv.weight.data *= self.mask - return self.conv(inputs) + out = self.conv(inputs) + return out # Define the ResidualBlock class in PyTorch class ResidualBlock(nn.Module): @@ -333,39 +341,109 @@ class ResidualBlock(nn.Module): def forward(self, inputs): x = self.conv1(inputs) x = self.pixel_conv(x) - x = self.conv2(x) + x = self.conv2(x) return inputs + x # Define your PyTorch model class PixelCNN(nn.Module): - def __init__(self, pixelcnn_input_shape, num_residual_blocks, num_pixelcnn_layers, num_embeddings): - super(PixelCNN, self).__init__() - self.pixel_conv_initial = PixelConvLayer(mask_type="A", in_channels=num_embeddings, out_channels=128, kernel_size=7, stride=1, padding=3) + def __init__(self, num_residual_blocks, num_pixelcnn_layers): + super(PixelCNN, self).__init__() + self.pixel_conv_initial = PixelConvLayer(mask_type="A", in_channels=1, out_channels=128, kernel_size=7, stride=1, padding=3) self.residual_blocks = nn.ModuleList([ResidualBlock(in_channels=128, filters=128) for _ in range(num_residual_blocks)]) self.pixel_conv_layers = nn.ModuleList([PixelConvLayer(mask_type="B", in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0) for _ in range(num_pixelcnn_layers)]) - self.out_conv = nn.Conv2d(in_channels=128, out_channels=num_embeddings, kernel_size=1, stride=1, padding=0) + self.out_conv = nn.Conv2d(in_channels=128, out_channels=1, 
kernel_size=1, stride=1, padding=0) def forward(self, inputs): - x = self.pixel_conv_initial(inputs) + x = self.pixel_conv_initial(inputs) # The type A mask + print("Shape returned from masked convolution is: ", x.detach().cpu().numpy().shape) + + # Creates structure. Residual blocks have type B mask for block in self.residual_blocks: x = block(x) for layer in self.pixel_conv_layers: x = layer(x) out = self.out_conv(x) + print("Final return of cnnmodel is: ", out.detach().cpu().numpy().shape) return out -# Create an instance of the PixelCNN model +# PixelCNN model batch_size = 128 -num_embeddings = 256 height = 256 width = 256 -pixelcnn_input_shape = (batch_size, num_embeddings, height, width) # Define your input shape num_residual_blocks = 5 # Define the number of residual blocks num_pixelcnn_layers = 5 # Define the number of PixelConvLayers -num_embeddings = model.num_embeddings # Define the number of embeddings -batch_size = 1 -cnn_model = PixelCNN(pixelcnn_input_shape, num_residual_blocks, num_pixelcnn_layers, num_embeddings) +# Takes codebook indices (of shape B, C, H, W), where C=1, H=W=256 +cnn_model = PixelCNN(num_residual_blocks, num_pixelcnn_layers).to(device) print(cnn_model) +loss = None + + +# Data prep for cnn model +encoder = model.__getattr__("encoder") +quantiser = model.__getattr__("quantiser") +decoder = model.__getattr__("decoder") + +optimiser = torch.optim.Adam(cnn_model.parameters(), learning_rate) + +for epoch_num, epoch in enumerate(range(num_epochs)): + + #Train + cnn_model.train() + + + for train_batch in tqdm(train_dataloader): + with torch.no_grad(): + encoder_output = encoder(train_batch) + _, _, encoding_indices = quantiser(encoder_output) + encoding_indices = encoding_indices.reshape(encoding_indices.size(0), 1, encoding_indices.size(1), encoding_indices.size(2)) + encoding_indices = encoding_indices.to(device) + encoding_indices = encoding_indices.float() + + output = cnn_model(encoding_indices) + + + + # TODO One hot encode to use CROSS ENTROPY LOSS. ASK TUTOR TMR + criterion = nn.MSELoss() + loss = criterion(output, encoding_indices) + + optimiser.zero_grad() # Reset gradients to zero for back-prop (not cumulative) + loss.backward() # Calculate grad + optimiser.step() # Adjust weights + + # Evaluate + cnn_model.eval() + + + for test_batch in (test_dataloader): + + + with torch.no_grad(): + encoder_output = encoder(test_batch) + _, _, encoding_indices = quantiser(encoder_output) + encoding_indices = encoding_indices.reshape(encoding_indices.size(0), 1, encoding_indices.size(1), encoding_indices.size(2)) + print("Encoding Indices shape is: ", encoding_indices.shape) + encoding_indices = encoding_indices.to(device) + encoding_indices = encoding_indices.float() + # Normalise to between 0 and 1 + + output = cnn_model(encoding_indices) + criterion = nn.MSELoss() + loss = criterion(output, encoding_indices) + + + + + print("Epoch {} of {}. 
Total Loss: {}".format(epoch_num, num_epochs, loss)) + +# Show one image +fig, ax = plt.subplots(1, 2) +ax[0].imshow(encoding_indices[0][0].long()) +ax[1].imshow(output[0][0].long()) +plt.show() + + + From 462e5fac2c58fa2ce0b1f34cebaf9baafa9a74f8 Mon Sep 17 00:00:00 2001 From: DruCallaghan Date: Mon, 16 Oct 2023 22:11:32 +1000 Subject: [PATCH 15/26] New PixCNN, non-functional img generation --- VQ_VAE_46992925/VQ_VAE | 299 +++++++++++++++++++++++++---------------- 1 file changed, 184 insertions(+), 115 deletions(-) diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE index b75e20937..a2d5cd02b 100644 --- a/VQ_VAE_46992925/VQ_VAE +++ b/VQ_VAE_46992925/VQ_VAE @@ -133,21 +133,20 @@ class Quantiser(nn.Module): encoding_indices = torch.argmin(dist, dim=-1) # in form B, W*H return encoding_indices + def output_from_indices(self, indices, output_shape): + quant_out = torch.index_select(self.embedding.weight, 0, indices.view(-1)) + quant_out = quant_out.reshape(output_shape).permute(0, 3, 1, 2) + return quant_out def forward(self, quant_input): - B, C, H, W = quant_input.shape - # Get the encoding indices encoding_indices = self.get_encoding_indices(quant_input) - # Select the embedding weights - quant_out = torch.index_select(self.embedding.weight, 0, encoding_indices.view(-1)) - quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) + quant_out = self.output_from_indices(encoding_indices, quant_input.shape) print(quant_out.shape, quant_input.shape) - # Compute losses commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) # TODO change to MSE codebook_loss = torch.mean((quant_out - quant_input.detach())**2) @@ -209,6 +208,11 @@ class VQVAE(nn.Module): return output, total_loss, encoding_indices + @torch.no_grad() + def img_from_indices(self, indices): + quant_out = self.quantiser.output_from_indices(indices, (26, 1, 32, 32)) + return self.decoder(quant_out) + # ------------------------------------------------ # Training @@ -216,7 +220,7 @@ losses = [] # for visualisation # Hyperparams learning_rate = 1.e-3 -num_epochs = 6 +num_epochs = 3 num_embeddings = 256 embedding_dim = 32 @@ -307,143 +311,208 @@ plot_results(2) # ------------- Pixel CNN # Define the PixelConvLayer class in PyTorch -class PixelConvLayer(nn.Module): - def __init__(self, mask_type, in_channels, out_channels, kernel_size, stride, padding): - super(PixelConvLayer, self).__init__() - self.mask_type = mask_type - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) - self.build_mask() - - def build_mask(self): - kernel_shape = self.conv.weight.shape - self.mask = torch.zeros(kernel_shape) - self.mask[:, :, :kernel_shape[2] // 2, :] = 1.0 - self.mask[:, :, kernel_shape[2] // 2, :kernel_shape[3] // 2] = 1.0 - if self.mask_type == "B": - self.mask[:, :, kernel_shape[2] // 2, kernel_shape[3] // 2] = 1.0 - - def forward(self, inputs): - with torch.no_grad(): - #print("Weights: ", self.conv.weight.data.cpu().numpy().shape, "Mask: ", self.mask.cpu().numpy().shape, - # "inputs: ", inputs.cpu().numpy().shape) - self.conv.weight.data *= self.mask - out = self.conv(inputs) - return out -# Define the ResidualBlock class in PyTorch -class ResidualBlock(nn.Module): - def __init__(self, in_channels, filters): - super(ResidualBlock, self).__init__() - self.conv1 = nn.Conv2d(in_channels, filters, kernel_size=1) - self.pixel_conv = PixelConvLayer(mask_type="B", in_channels=filters, out_channels=filters // 2, kernel_size=3, stride=1, padding=1) - self.conv2 = nn.Conv2d(filters // 2, filters, 
kernel_size=1) - - def forward(self, inputs): - x = self.conv1(inputs) - x = self.pixel_conv(x) - x = self.conv2(x) - return inputs + x - -# Define your PyTorch model -class PixelCNN(nn.Module): - def __init__(self, num_residual_blocks, num_pixelcnn_layers): - super(PixelCNN, self).__init__() - self.pixel_conv_initial = PixelConvLayer(mask_type="A", in_channels=1, out_channels=128, kernel_size=7, stride=1, padding=3) - self.residual_blocks = nn.ModuleList([ResidualBlock(in_channels=128, filters=128) for _ in range(num_residual_blocks)]) - self.pixel_conv_layers = nn.ModuleList([PixelConvLayer(mask_type="B", in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0) for _ in range(num_pixelcnn_layers)]) - self.out_conv = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0) - - def forward(self, inputs): - x = self.pixel_conv_initial(inputs) # The type A mask - print("Shape returned from masked convolution is: ", x.detach().cpu().numpy().shape) - - # Creates structure. Residual blocks have type B mask - for block in self.residual_blocks: - x = block(x) - for layer in self.pixel_conv_layers: - x = layer(x) - out = self.out_conv(x) - print("Final return of cnnmodel is: ", out.detach().cpu().numpy().shape) - return out +class MaskedConvolution(nn.Module): + + def __init__(self, in_channels, out_channels, mask, dilation=1): + + super(MaskedConvolution, self).__init__() + kernel_size = (mask.shape[0], mask.shape[1]) + padding = ([dilation*(kernel_size[i] - 1) // 2 for i in range(2)]) + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation) + + # Mask as buffer (must be moved with devices) + self.register_buffer('mask', mask[None,None]) + + def forward(self, x): + self.conv.weight.data *= self.mask # Set all following weights to 0 + return self.conv(x) + -# PixelCNN model -batch_size = 128 -height = 256 -width = 256 +class VerticalConv(MaskedConvolution): + # Masks all pixels below + def __init__(self, in_channels, out_channels, kernel_size=3, mask_center=False, dilation=1): + mask = torch.ones(kernel_size, kernel_size) + mask[kernel_size//2+1:,:] = 0 + # For the first convolution, mask center row + if mask_center: + mask[kernel_size//2,:] = 0 -num_residual_blocks = 5 # Define the number of residual blocks -num_pixelcnn_layers = 5 # Define the number of PixelConvLayers + super().__init__(in_channels, out_channels, mask, dilation=dilation) -# Takes codebook indices (of shape B, C, H, W), where C=1, H=W=256 -cnn_model = PixelCNN(num_residual_blocks, num_pixelcnn_layers).to(device) -print(cnn_model) +class HorizontalConv(MaskedConvolution): -loss = None + def __init__(self, in_channels, out_channels, kernel_size=3, mask_center=False, dilation=1): + # Mask out all pixels on the left. 
(Note that kernel has a size of 1 + # in height because we only look at the pixel in the same row) + mask = torch.ones(1,kernel_size) + mask[0,kernel_size//2+1:] = 0 + # For first convolution, mask center pixel + if mask_center: + mask[0,kernel_size//2] = 0 -# Data prep for cnn model + super().__init__(in_channels, out_channels, mask, dilation=dilation) + + +class GatedMaskedConv(nn.Module): + + def __init__(self, in_channels, dilation=1): + + super(GatedMaskedConv, self).__init__() + self.conv_vert = VerticalConv(in_channels, out_channels=2*in_channels, dilation=dilation) + self.conv_horiz = HorizontalConv(in_channels, out_channels=2*in_channels, dilation=dilation) + self.conv_vert_to_horiz = nn.Conv2d(2*in_channels, 2*in_channels, kernel_size=1, padding=0) + self.conv_horiz_1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0) + + def forward(self, v_stack, h_stack): + # Vertical stack (left) + v_stack_feat = self.conv_vert(v_stack) + v_val, v_gate = v_stack_feat.chunk(2, dim=1) + v_stack_out = torch.tanh(v_val) * torch.sigmoid(v_gate) + + # Horizontal stack (right) + h_stack_feat = self.conv_horiz(h_stack) + h_stack_feat = h_stack_feat + self.conv_vert_to_horiz(v_stack_feat) + h_val, h_gate = h_stack_feat.chunk(2, dim=1) + h_stack_feat = torch.tanh(h_val) * torch.sigmoid(h_gate) + h_stack_out = self.conv_horiz_1x1(h_stack_feat) + h_stack_out = h_stack_out + h_stack + + return v_stack_out, h_stack_out + +class PixelCNN(nn.Module): + + def __init__(self, in_channels, hidden_channels): + super().__init__() + + # Initial convolutions skipping the center pixel + self.conv_vstack = VerticalConv(in_channels, hidden_channels, mask_center=True) + self.conv_hstack = HorizontalConv(in_channels, hidden_channels, mask_center=True) + # Convolution block of PixelCNN. Uses dilation instead of downscaling + self.conv_layers = nn.ModuleList([ + GatedMaskedConv(hidden_channels), + GatedMaskedConv(hidden_channels, dilation=2), + GatedMaskedConv(hidden_channels), + GatedMaskedConv(hidden_channels, dilation=4), + GatedMaskedConv(hidden_channels), + GatedMaskedConv(hidden_channels, dilation=2), + GatedMaskedConv(hidden_channels) + ]) + # Output classification convolution (1x1) + # The output channels should be in_channels*number of embeddings to learn continuous space and calc. 
CrossEntropyLoss + self.conv_out = nn.Conv2d(hidden_channels, in_channels*num_embeddings, kernel_size=1, padding=0) + + + def forward(self, x): + # Scale input from 0 to 255 to -1 to 1 + x = (x.float() / 255.0) * 2 - 1 + + # Initial convolutions + v_stack = self.conv_vstack(x) + h_stack = self.conv_hstack(x) + # Gated Convolutions + for layer in self.conv_layers: + v_stack, h_stack = layer(v_stack, h_stack) + # 1x1 classification convolution + # Apply ELU (exponential activation function) before 1x1 convolution for non-linearity on residual connection + out = self.conv_out(F.elu(h_stack)) + + # Output dimensions: [Batch, Classes, Channels, Height, Width] (classes = num_embeddings) + out = out.reshape(out.shape[0], num_embeddings, out.shape[1]//256, out.shape[2], out.shape[3]) + return out + + """Indices shape should be in form B C H W + Pixels to fill should be marked with -1""" + @torch.no_grad() + def sample(self, ind_shape, ind=None): + # Create tensor of indices (all -1) + if ind is None: + ind = torch.zeros(ind_shape, dtype=torch.long).to(device) - 1 + # Generation loop (iterating through pixels across channels) + for h in range(ind_shape[2]): # Heights + for w in range(ind_shape[3]): # Widths + for c in range(ind_shape[1]): # Channels + # Skip if not to be filled (-1) + if (ind[:,c,h,w] != -1).all().item(): + continue + # Only have to input upper half of ind (rest are masked anyway) + pred = self.forward(ind[:,:,:h+1,:]) + probs = F.softmax(pred[:,:,c,h,w], dim=-1) + ind[:,c,h,w] = torch.multinomial(probs, num_samples=1).squeeze(dim=-1) + return ind + + + +cnn_model = PixelCNN(in_channels=1, hidden_channels=128) +optimiser = torch.optim.Adam(cnn_model.parameters(), learning_rate) + +# For getting codebook indices encoder = model.__getattr__("encoder") quantiser = model.__getattr__("quantiser") decoder = model.__getattr__("decoder") -optimiser = torch.optim.Adam(cnn_model.parameters(), learning_rate) - -for epoch_num, epoch in enumerate(range(num_epochs)): +for epoch in range(num_epochs): - #Train cnn_model.train() - - for train_batch in tqdm(train_dataloader): + for train_batch in train_dataloader: + + # Get the quantised outputs with torch.no_grad(): encoder_output = encoder(train_batch) - _, _, encoding_indices = quantiser(encoder_output) - encoding_indices = encoding_indices.reshape(encoding_indices.size(0), 1, encoding_indices.size(1), encoding_indices.size(2)) - encoding_indices = encoding_indices.to(device) - encoding_indices = encoding_indices.float() - - output = cnn_model(encoding_indices) + _, _, indices = quantiser(encoder_output) + indices = indices.reshape(indices.size(0), 1, indices.size(1), indices.size(2)).to(device) + + output = cnn_model(indices) - - - # TODO One hot encode to use CROSS ENTROPY LOSS. 
ASK TUTOR TMR - criterion = nn.MSELoss() - loss = criterion(output, encoding_indices) + # Compute loss + nll = F.cross_entropy(output, indices, reduction='none') # Negative log-likelihood + bpd = nll.mean(dim=[1,2,3]) * np.log2(np.exp(1)) # Bits per dimension + loss = bpd.mean() optimiser.zero_grad() # Reset gradients to zero for back-prop (not cumulative) - loss.backward() # Calculate grad + loss.backward() # Calculate grad optimiser.step() # Adjust weights - # Evaluate cnn_model.eval() - - for test_batch in (test_dataloader): - - + for test_batch in test_dataloader: + # Get the quantised outputs with torch.no_grad(): encoder_output = encoder(test_batch) - _, _, encoding_indices = quantiser(encoder_output) - encoding_indices = encoding_indices.reshape(encoding_indices.size(0), 1, encoding_indices.size(1), encoding_indices.size(2)) - print("Encoding Indices shape is: ", encoding_indices.shape) - encoding_indices = encoding_indices.to(device) - encoding_indices = encoding_indices.float() - # Normalise to between 0 and 1 - - output = cnn_model(encoding_indices) - criterion = nn.MSELoss() - loss = criterion(output, encoding_indices) + _, _, indices = quantiser(encoder_output) + indices = indices.reshape(indices.size(0), 1, indices.size(1), indices.size(2)).to(device) + + output = cnn_model(indices) + print("Indices is shape: ", indices.detach().cpu().numpy().shape) + print("Output is shape: ", output.detach().cpu().numpy().shape) - - + + # Compute loss + nll = F.cross_entropy(output, indices, reduction='none') # Negative log-likelihood + bpd = nll.mean(dim=[1,2,3]) * np.log2(np.exp(1)) # Bits per dimension + loss = bpd.mean() + print("Epoch {} of {}. Total Loss: {}".format(epoch_num, num_epochs, loss)) - + # Show one image -fig, ax = plt.subplots(1, 2) -ax[0].imshow(encoding_indices[0][0].long()) -ax[1].imshow(output[0][0].long()) -plt.show() +print(" > Showing Images") +print("Real indices shape: ", indices.detach().cpu().numpy().shape) +gen_indices = cnn_model.sample((1, 1, 16, 16)) +print("Gen indices shape: ", gen_indices.detach().cpu().numpy().shape) + +fig, ax = plt.subplots(2, 2) +ax[0, 0].set_title("Real Indices") +ax[0, 0].imshow(indices[0][0].long()) +ax[0, 1].set_title("Real Decoded") +ax[0, 1].imshow(model.img_from_indices(indices)[0][0]) +ax[1, 0].set_title("Generated Indices") +ax[1, 0].imshow(gen_indices[0][0]) +ax[1, 1].set_title("Generated Image") +ax[1, 1].imshow(model.img_from_indices(gen_indices)[0][0]) +plt.show() \ No newline at end of file From 7237f23c2bec0959e2225262f4544650563fda47 Mon Sep 17 00:00:00 2001 From: DruCallaghan Date: Mon, 16 Oct 2023 22:37:03 +1000 Subject: [PATCH 16/26] PixCNN implemented, unclear images --- VQ_VAE_46992925/VQ_VAE | 47 ++++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE index a2d5cd02b..68c8791ba 100644 --- a/VQ_VAE_46992925/VQ_VAE +++ b/VQ_VAE_46992925/VQ_VAE @@ -209,8 +209,8 @@ class VQVAE(nn.Module): return output, total_loss, encoding_indices @torch.no_grad() - def img_from_indices(self, indices): - quant_out = self.quantiser.output_from_indices(indices, (26, 1, 32, 32)) + def img_from_indices(self, indices, quant_out_shape): + quant_out = self.quantiser.output_from_indices(indices, quant_out_shape) # Output is currently 32*32 img with 32 channels return self.decoder(quant_out) # ------------------------------------------------ @@ -220,7 +220,7 @@ losses = [] # for visualisation # Hyperparams learning_rate = 1.e-3 -num_epochs = 3 
+num_epochs = 15 num_embeddings = 256 embedding_dim = 32 @@ -452,7 +452,7 @@ encoder = model.__getattr__("encoder") quantiser = model.__getattr__("quantiser") decoder = model.__getattr__("decoder") -for epoch in range(num_epochs): +for epoch_num, epoch in enumerate(range(num_epochs)): cnn_model.train() @@ -496,23 +496,26 @@ for epoch in range(num_epochs): loss = bpd.mean() print("Epoch {} of {}. Total Loss: {}".format(epoch_num, num_epochs, loss)) - -# Show one image -print(" > Showing Images") -print("Real indices shape: ", indices.detach().cpu().numpy().shape) -gen_indices = cnn_model.sample((1, 1, 16, 16)) -print("Gen indices shape: ", gen_indices.detach().cpu().numpy().shape) - -fig, ax = plt.subplots(2, 2) -ax[0, 0].set_title("Real Indices") -ax[0, 0].imshow(indices[0][0].long()) -ax[0, 1].set_title("Real Decoded") -ax[0, 1].imshow(model.img_from_indices(indices)[0][0]) +with torch.no_grad(): + # Show one image + print(" > Showing Images") + print("Real indices shape: ", indices.detach().cpu().numpy().shape) + gen_indices = cnn_model.sample((1, 1, 32, 32)) + print("Gen indices shape: ", gen_indices.detach().cpu().numpy().shape) - -ax[1, 0].set_title("Generated Indices") -ax[1, 0].imshow(gen_indices[0][0]) -ax[1, 1].set_title("Generated Image") -ax[1, 1].imshow(model.img_from_indices(gen_indices)[0][0]) -plt.show() \ No newline at end of file + fig, ax = plt.subplots(2, 2) + + for a in ax.flatten(): + a.axis('off') + + ax[0, 0].set_title("Real Indices") + ax[0, 0].imshow(indices[0][0].long(), cmap='gray') + ax[0, 1].set_title("Real Decoded") + ax[0, 1].imshow(model.img_from_indices(indices, quant_out_shape=(26, 32, 32, 32))[0][0], cmap='gray') + + ax[1, 0].set_title("Generated Indices") + ax[1, 0].imshow(gen_indices[0][0], cmap='gray') + ax[1, 1].set_title("Generated Image") + ax[1, 1].imshow(model.img_from_indices(gen_indices, (1, 32, 32, 32))[0][0], cmap='gray') + plt.show() \ No newline at end of file From 4e5ae1a882d9c19cbc8eb5e4fc26c65b8e4ba760 Mon Sep 17 00:00:00 2001 From: Dru Callaghan Date: Tue, 17 Oct 2023 14:52:57 +1000 Subject: [PATCH 17/26] First working models for GPU --- VQ_VAE_46992925/VQ_VAE | 53 ++++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE index 68c8791ba..cc8a02a32 100644 --- a/VQ_VAE_46992925/VQ_VAE +++ b/VQ_VAE_46992925/VQ_VAE @@ -9,7 +9,7 @@ import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as transforms import numpy as np -from tqdm import tqdm +from tqdm.auto import tqdm from PIL import Image import torch.utils.data from torchvision import datasets, transforms, utils @@ -22,14 +22,15 @@ print("Torch version ", torch.__version__) # ------------------------------------------------ # Data Loader -path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" -#path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/keras_png_slices_data/keras_png_slices_data/" +#path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" +path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/keras_png_slices_data/keras_png_slices_data/" def load_data_from_folder(name): data = [] - i = 0 - - for filename in tqdm([f for f in os.listdir(path+name) if f.lower().endswith('.png')]): # tqdm adds loading bar! 
+ #i = 0 + list_files = [f for f in os.listdir(path+name) if f.lower().endswith('.png')] + + for filename in tqdm(list_files): # tqdm adds loading bar image_path = os.path.join(path+name, filename) image = Image.open(image_path).convert('L') # Convert to grayscale (single channel) @@ -41,9 +42,9 @@ def load_data_from_folder(name): data.append(image) - if i == 25: - return np.array(data) - i += 1 + #if i == 25: + # return np.array(data) + #i += 1 return np.array(data) @@ -145,7 +146,7 @@ class Quantiser(nn.Module): quant_out = self.output_from_indices(encoding_indices, quant_input.shape) - print(quant_out.shape, quant_input.shape) + #print(quant_out.shape, quant_input.shape) # Compute losses commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) # TODO change to MSE @@ -220,13 +221,11 @@ losses = [] # for visualisation # Hyperparams learning_rate = 1.e-3 -num_epochs = 15 +num_epochs = 30 num_embeddings = 256 embedding_dim = 32 - - model = VQVAE(num_embeddings=num_embeddings, embedding_dim=embedding_dim).to(device) print(model) @@ -325,7 +324,7 @@ class MaskedConvolution(nn.Module): self.register_buffer('mask', mask[None,None]) def forward(self, x): - self.conv.weight.data *= self.mask # Set all following weights to 0 + self.conv.weight.data = self.conv.weight.data.to(device) * self.mask.to(device) # Set all following weights to 0 (make sure it is in GPU) return self.conv(x) @@ -443,8 +442,12 @@ class PixelCNN(nn.Module): return ind +# ------------------------------- +# Training PixCNN + +num_epochs = 30 -cnn_model = PixelCNN(in_channels=1, hidden_channels=128) +cnn_model = PixelCNN(in_channels=1, hidden_channels=128).to(device) optimiser = torch.optim.Adam(cnn_model.parameters(), learning_rate) # For getting codebook indices @@ -460,7 +463,7 @@ for epoch_num, epoch in enumerate(range(num_epochs)): # Get the quantised outputs with torch.no_grad(): - encoder_output = encoder(train_batch) + encoder_output = encoder(train_batch.to(device)) _, _, indices = quantiser(encoder_output) indices = indices.reshape(indices.size(0), 1, indices.size(1), indices.size(2)).to(device) @@ -480,14 +483,14 @@ for epoch_num, epoch in enumerate(range(num_epochs)): for test_batch in test_dataloader: # Get the quantised outputs with torch.no_grad(): - encoder_output = encoder(test_batch) + encoder_output = encoder(test_batch.to(device)) _, _, indices = quantiser(encoder_output) indices = indices.reshape(indices.size(0), 1, indices.size(1), indices.size(2)).to(device) output = cnn_model(indices) - print("Indices is shape: ", indices.detach().cpu().numpy().shape) - print("Output is shape: ", output.detach().cpu().numpy().shape) + #print("Indices is shape: ", indices.detach().cpu().numpy().shape) + #print("Output is shape: ", output.detach().cpu().numpy().shape) # Compute loss @@ -500,9 +503,9 @@ for epoch_num, epoch in enumerate(range(num_epochs)): with torch.no_grad(): # Show one image print(" > Showing Images") - print("Real indices shape: ", indices.detach().cpu().numpy().shape) + #print("Real indices shape: ", indices.detach().cpu().numpy().shape) gen_indices = cnn_model.sample((1, 1, 32, 32)) - print("Gen indices shape: ", gen_indices.detach().cpu().numpy().shape) + #print("Gen indices shape: ", gen_indices.detach().cpu().numpy().shape) fig, ax = plt.subplots(2, 2) @@ -510,12 +513,12 @@ with torch.no_grad(): a.axis('off') ax[0, 0].set_title("Real Indices") - ax[0, 0].imshow(indices[0][0].long(), cmap='gray') + ax[0, 0].imshow(indices[0][0].long().cpu().numpy(), cmap='gray') ax[0, 1].set_title("Real 
Decoded") - ax[0, 1].imshow(model.img_from_indices(indices, quant_out_shape=(26, 32, 32, 32))[0][0], cmap='gray') + ax[0, 1].imshow(model.img_from_indices(indices, quant_out_shape=(26, 32, 32, 32))[0][0].cpu().numpy(), cmap='gray') ax[1, 0].set_title("Generated Indices") - ax[1, 0].imshow(gen_indices[0][0], cmap='gray') + ax[1, 0].imshow(gen_indices[0][0].cpu().numpy(), cmap='gray') ax[1, 1].set_title("Generated Image") - ax[1, 1].imshow(model.img_from_indices(gen_indices, (1, 32, 32, 32))[0][0], cmap='gray') + ax[1, 1].imshow(model.img_from_indices(gen_indices, (1, 32, 32, 32))[0][0].cpu().numpy(), cmap='gray') plt.show() \ No newline at end of file From fa71f71a15899a16351f6bf572816cad6eb1677c Mon Sep 17 00:00:00 2001 From: Dru Callaghan Date: Thu, 19 Oct 2023 13:12:24 +1000 Subject: [PATCH 18/26] Working dataset, modules and train --- Modules/PixelCNN.py | 136 +++++++++ Modules/VQ_VAE.py | 129 +++++++++ VQ_VAE_46992925/{VQ_VAE => VQ_VAE_original} | 0 __pycache__/dataset.cpython-39.pyc | Bin 0 -> 2534 bytes __pycache__/modules.cpython-39.pyc | Bin 0 -> 8167 bytes dataset.py | 70 +++++ modules.py | 268 +++++++++++++++++ train.py | 300 ++++++++++++++++++++ 8 files changed, 903 insertions(+) create mode 100644 Modules/PixelCNN.py create mode 100644 Modules/VQ_VAE.py rename VQ_VAE_46992925/{VQ_VAE => VQ_VAE_original} (100%) create mode 100644 __pycache__/dataset.cpython-39.pyc create mode 100644 __pycache__/modules.cpython-39.pyc create mode 100644 dataset.py create mode 100644 modules.py create mode 100644 train.py diff --git a/Modules/PixelCNN.py b/Modules/PixelCNN.py new file mode 100644 index 000000000..9fbf984d2 --- /dev/null +++ b/Modules/PixelCNN.py @@ -0,0 +1,136 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data + + +"""Autoregressive PixelCNN model""" +class PixelCNN(nn.Module): + + def __init__(self, in_channels, hidden_channels, num_embeddings): + super(PixelCNN, self).__init__() + # Equal to the number of embeddings in the VQVAE + self.num_embeddings = num_embeddings + # Initial convolutions skipping the center pixel + self.conv_vstack = VerticalConv(in_channels, hidden_channels, mask_center=True) + self.conv_hstack = HorizontalConv(in_channels, hidden_channels, mask_center=True) + # Convolution block of PixelCNN. Uses dilation instead of downscaling + self.conv_layers = nn.ModuleList([ + GatedMaskedConv(hidden_channels), + GatedMaskedConv(hidden_channels, dilation=2), + GatedMaskedConv(hidden_channels), + GatedMaskedConv(hidden_channels, dilation=4), + GatedMaskedConv(hidden_channels), + GatedMaskedConv(hidden_channels, dilation=2), + GatedMaskedConv(hidden_channels) + ]) + # Output classification convolution (1x1) + # The output channels should be in_channels*number of embeddings to learn continuous space and calc. 
CrossEntropyLoss + self.conv_out = nn.Conv2d(hidden_channels, in_channels*self.num_embeddings, kernel_size=1, padding=0) + + + def forward(self, x): + # Scale input from 0 to 255 to -1 to 1 + x = (x.float() / 255.0) * 2 - 1 + + # Initial convolutions + v_stack = self.conv_vstack(x) + h_stack = self.conv_hstack(x) + # Gated Convolutions + for layer in self.conv_layers: + v_stack, h_stack = layer(v_stack, h_stack) + # 1x1 classification convolution + # Apply ELU (exponential activation function) before 1x1 convolution for non-linearity on residual connection + out = self.conv_out(F.elu(h_stack)) + + # Output dimensions: [Batch, Classes, Channels, Height, Width] (classes = num_embeddings) + out = out.reshape(out.shape[0], self.num_embeddings, out.shape[1]//256, out.shape[2], out.shape[3]) + return out + + """Indices shape should be in form B C H W + Pixels to fill should be marked with -1""" + @torch.no_grad() + def sample(self, ind_shape, ind): + # Generation loop (iterating through pixels across channels) + for h in range(ind_shape[2]): # Heights + for w in range(ind_shape[3]): # Widths + for c in range(ind_shape[1]): # Channels + # Skip if not to be filled (-1) + if (ind[:,c,h,w] != -1).all().item(): + continue + # Only have to input upper half of ind (rest are masked anyway) + pred = self.forward(ind[:,:,:h+1,:]) + probs = F.softmax(pred[:,:,c,h,w], dim=-1) + ind[:,c,h,w] = torch.multinomial(probs, num_samples=1).squeeze(dim=-1) + return ind + + +"""A general Masked convolution, with a the mask as a parameter.""" +class MaskedConvolution(nn.Module): + + def __init__(self, in_channels, out_channels, mask, dilation=1): + + super(MaskedConvolution, self).__init__() + kernel_size = (mask.shape[0], mask.shape[1]) + padding = ([dilation*(kernel_size[i] - 1) // 2 for i in range(2)]) + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation) + + # Mask as buffer (must be moved with devices) + self.register_buffer('mask', mask[None,None]) + + def forward(self, x): + self.conv.weight.data *= self.mask # Set all following weights to 0 (make sure it is in GPU) + return self.conv(x) + + +class VerticalConv(MaskedConvolution): + # Masks all pixels below + def __init__(self, in_channels, out_channels, kernel_size=3, mask_center=False, dilation=1): + mask = torch.ones(kernel_size, kernel_size) + mask[kernel_size//2+1:,:] = 0 + # For the first convolution, mask center row + if mask_center: + mask[kernel_size//2,:] = 0 + + super().__init__(in_channels, out_channels, mask, dilation=dilation) + +class HorizontalConv(MaskedConvolution): + + def __init__(self, in_channels, out_channels, kernel_size=3, mask_center=False, dilation=1): + # Mask out all pixels on the left. 
(Note that kernel has a size of 1 + # in height because we only look at the pixel in the same row) + mask = torch.ones(1,kernel_size) + mask[0,kernel_size//2+1:] = 0 + + # For first convolution, mask center pixel + if mask_center: + mask[0,kernel_size//2] = 0 + + super().__init__(in_channels, out_channels, mask, dilation=dilation) + +"""Gated Convolutions Model""" +class GatedMaskedConv(nn.Module): + + def __init__(self, in_channels, dilation=1): + + super(GatedMaskedConv, self).__init__() + self.conv_vert = VerticalConv(in_channels, out_channels=2*in_channels, dilation=dilation) + self.conv_horiz = HorizontalConv(in_channels, out_channels=2*in_channels, dilation=dilation) + self.conv_vert_to_horiz = nn.Conv2d(2*in_channels, 2*in_channels, kernel_size=1, padding=0) + self.conv_horiz_1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0) + + def forward(self, v_stack, h_stack): + # Vertical stack (left) + v_stack_feat = self.conv_vert(v_stack) + v_val, v_gate = v_stack_feat.chunk(2, dim=1) + v_stack_out = torch.tanh(v_val) * torch.sigmoid(v_gate) + + # Horizontal stack (right) + h_stack_feat = self.conv_horiz(h_stack) + h_stack_feat = h_stack_feat + self.conv_vert_to_horiz(v_stack_feat) + h_val, h_gate = h_stack_feat.chunk(2, dim=1) + h_stack_feat = torch.tanh(h_val) * torch.sigmoid(h_gate) + h_stack_out = self.conv_horiz_1x1(h_stack_feat) + h_stack_out = h_stack_out + h_stack + + return v_stack_out, h_stack_out \ No newline at end of file diff --git a/Modules/VQ_VAE.py b/Modules/VQ_VAE.py new file mode 100644 index 000000000..a0a1ca3fc --- /dev/null +++ b/Modules/VQ_VAE.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data + +# -------------------------------- +# VQVAE MODEL + +"""The VQ-VAE Model""" +class VQVAE(nn.Module): + + def __init__(self, num_embeddings, embedding_dim): + super(VQVAE, self).__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + + self.encoder = Encoder() + self.quantiser = Quantiser(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + self.decoder = Decoder() + + def forward(self, x): + # Input shape is B, C, H, W + quant_input = self.encoder(x) + quant_out, quant_loss, encoding_indices = self.quantiser(quant_input) + output = self.decoder(quant_out) + + # Reconstruction Loss, and find the total loss + reconstruction_loss = F.mse_loss(x, output) + total_loss = quant_loss + reconstruction_loss + + return output, total_loss, encoding_indices + + """Function while allows output to be calculated directly from indices + param quant_out_shape is the shape that the quantiser is expected to return""" + @torch.no_grad() + def img_from_indices(self, indices, quant_out_shape): + quant_out = self.quantiser.output_from_indices(indices, quant_out_shape) # Output is currently 32*32 img with 32 channels + return self.decoder(quant_out) + +"""The Encoder Model used in VQ-VAE""" +class Encoder(nn.Module): + def __init__(self, ): + super(Encoder, self).__init__() + + self.encoder = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.Conv2d(16, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(), + ) + + def forward(self, x): + out = self.encoder(x) + return out + +"""The VectorQuantiser Model used in VQ-VAE""" +class Quantiser(nn.Module): + def __init__(self, num_embeddings, embedding_dim) -> 
None: + super(Quantiser, self).__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.beta = 0.2 + + self.embedding = self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + + """Returns the encoding indices from the input""" + def get_encoding_indices(self, quant_input): + # Flatten + quant_input = quant_input.permute(0, 2, 3, 1) + quant_input = quant_input.reshape((quant_input.size(0), -1, quant_input.size(-1))) + + # Compute pairwise distances + dist = torch.cdist(quant_input, self.embedding.weight[None, :].repeat((quant_input.size(0), 1, 1))) + + # Find index of nearest embedding + encoding_indices = torch.argmin(dist, dim=-1) # in form B, W*H + return encoding_indices + + """Returns the output from the encoding indices""" + def output_from_indices(self, indices, output_shape): + quant_out = torch.index_select(self.embedding.weight, 0, indices.view(-1)) + quant_out = quant_out.reshape(output_shape).permute(0, 3, 1, 2) + return quant_out + + def forward(self, quant_input): + # Finds the encoding indices + encoding_indices = self.get_encoding_indices(quant_input) + # Gets the output based on the encoding indices + quant_out = self.output_from_indices(encoding_indices, quant_input.shape) + + # Losses + commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) + codebook_loss = torch.mean((quant_out - quant_input.detach())**2) + loss = codebook_loss + self.beta*commitment_loss + + # Straight through gradient estimator for backprop + quant_out = quant_input + (quant_out - quant_input).detach() + + # Reshapes encoding indices to 'B, H, W' + encoding_indices = encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) + + return quant_out, loss, encoding_indices + +"""The Decoder Model used in VQ-VAE""" +class Decoder(nn.Module): + def __init__(self, ) -> None: + super(Decoder, self).__init__() + + self.decoder = nn.Sequential( + nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(1), + ) + + def forward(self, x): + out = self.decoder(x) + return out \ No newline at end of file diff --git a/VQ_VAE_46992925/VQ_VAE b/VQ_VAE_46992925/VQ_VAE_original similarity index 100% rename from VQ_VAE_46992925/VQ_VAE rename to VQ_VAE_46992925/VQ_VAE_original diff --git a/__pycache__/dataset.cpython-39.pyc b/__pycache__/dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7aefde712ca7610679370bd9f9a63af85d22ddb GIT binary patch literal 2534 zcmcImON$#v5bo}IY9#HhZbE+jrUko*GaX!0j=%x~yxAYsqG@!2)l=PFU0q#|uN5|%9)Z@{ z|Ab$x6Y?7ljy3}ZH-V%NT||5m(TF9~r~1o$roTpN`X;T`EZ>GXBXLsKclDf^)Kbs) z=IeEzeL*5CvJXgP?=!yUZ2$MXFxR9;9-x`Q; zPxiuux8gJ$ps>hO*pZ;XA(%)u+fGKD7fg^uyN69z&JqYDfv5e1({MHHUntL>Gd5r| z`jnJbNdmJZQ>&yUo3V*K=MCVUId1~*7WUNnj`V2+mK?aHeTVGR8J)V39XXL3)n=e2 zs4b;DE#(|5FlcNn>;HwDI$k_h0;S zW6=EJr*Hl^`|FK1Q)bG;Oc}C>dK8?oXy|hj=p>MQ6S@H@$%Dp(eosoePk&^)<`#KQ z1bXs0Q>?SD%(&p`bE;~E2r~&;k?Pey<8F62>h~XpNzzHehqA-Nxaf>_J3Jb7!ckX# zdgbcXjjJ2oJ9%%E@~n{EjrEPotu3BzbZ_VB@bc%E*1KQxVcxwL76ljCtt?FTWGpYD zTo+SL@}e``Q{J^CmPHTtTyKG4FM!CneKZw1K@ewg5d@uUNe+oAuLHB#BhUkfc=bcf zjzb*&cdzJMSDgzA1e9szC(%T11-kc;S%vO@(SClRy}hQqWiX_2(3!d9%Dtc8=b6lf zs%=7;!XzGZNreUzx+O*!CNh zmdaHB#9Tw1XLIZQy53){i!e%e-VUjHx=9&d`8%Vwl9pDh%VLJw34%smgk)9b7RwC{i 
zoA&`I9k~q^mGHJD-U4}Ljfyyt%EX8%4^OG&LKTL5Y2wb1p@^hHBpU0$Pv!GKHvPA4V;>Vbw?1kp2OMiI~^R>4YZ zANU7fPHcqUsadU;HiBzmP3_VF2z4TJMgd^mX{~fCZpSNKaEtW}fVVWJb)2VU_kxx& zW)vJ{Z|Z9PM}eV*TR`$%Z6Sc#=?N|Al+556h`=@hR#foTrem`+%m#pxR+2yFLb*JP z|Y1cPX*S4$LSY3Ez5GXGQ(mWa^$TxyucN8Yoiue!?5}2t1k3C`vK1LPDV)~u{lKAqw z%su1U0_WOIL_kuuNf^k(or!}rUiY1^$=OthYtZ-#zqjb0@t{`8-KPvEMfjq`R r8TY=rTSZ_^tLW%D>|lNA+Zf=kc9$ZUqYq6$8)W?|&&ec+oI^%Et6#qj4X<_{S(hqhW+9G&`!Hpj5*eD7Boj0LnsG2c@1<>Yyx!OQ0-8b?mf=_N8zc?PX~%p?x`AL3<^iaVaz( z8f(uDKl`BHOM-SUUVZoNcVB&dHRy$_Z?&I9od+8mtKEJWb>cI;y|K|`IsI69BmOcf zkpZq44f8{zCER)HHLUFI8+POzyC?aZAF4*_L&G*29(FIKj%?a2mg0)6lkxnI@1aPH zM2(C?^VmEAi$|#Y*2qL@C(faPe>W+CYKH2FR4*wDwGtZG!rVAD(|T{v^`q`q6o&2I zcAPGZ>wehou2oYf9_&S0o7gS&e81gmC%&(F3hj_qdcD^KD|A|l@=E#bLC{OuF=*xe zakcbXl#dIu%Bhvwanxx|);d*b+4q~BAdY?i)cBwF=H}j@)!GU=o%K$z6|YA@J6Rv> zuSelvJs51pU;fJ7yLay1*?g_v9CV{z5^vtQb?2qkhf(*==7WBB@1?KazP0&IwAbH! zD@c+^_g?JrVpuQUyh%nB zkB165@c=)O$H+K_6i(CuxP1)1pQugiz&fxG9JI}mg&sGtj%~=qJuweTBX?wvN+YLb zKQvC&#+swcSV~t}ETBl$H&U+~M}DUt$7`PE5mWUftqNZujJ-jUGuiK>TFK}foi2$y zAXE$s+s!CWT@2xb%Ni2s#fcs?lXkx+i|b|XDaB=OQfl?}`+*LxqY(=fm=(38Dpt)L zE@UjpN6cs_xiRL*H8e)XE;&J7mj8TABLY3%Zf-rxoANv;*~; z?_<|a)+3stKaHhQC+X|vPFm{q{cRnD`WH!D$Q8I5bvm$<30IfUtI1k1B?>pD`gmro z;w`vJs7(-N9vV8I_;V;e4o3tx92m8uxoT5!R{pR~#SzpPp)!ja)W^iybqL&qCHy-_ z4ybMj{l$Oj#E^dT6Z_qB^W9?pNtx&)s8iZFweHDR;X3&*4%Ibn~Hr%;oO^esLYe6coUl zWa_T6Q&`z4u3}VbL6xVpL==SDMRdnv76bGzv7l~D)Kb5|<~+Tl zV!)MW%X+G|n5haVGmqE_DDkhMD0Gg>@z+W`82xq)9!$J-!_Itk!^u2z!xc>|K@+Q* zu(1J9cXgBubZ-VI$Sdx|f7xkPDW_doNVn!B?r!GF%p=pLYxBk`Wo>N#;IH`n@DGs&83q>V?jV7s=|~V2 zP8kNMKaZiQ(+t}%#Ftrj-;dhcI|+rk7X|na^me!1%Mn`BbW&S}e+tzp1r&F(vK=M< zd582S2IjtHN!_&8Er9cewcLoxlN)tYe4H z7?7p}q8uG>WmPk>>?ytF|Ha1E|d==t{EJIHUi7zQivet*bAX_ z@(-ShvVw|G#%xjT>FXNg{E(I|TQ#$vPI?Y0vJdlI%Z78q1+_4|*cnc56}&MBXG>)*xfn7D746|-(# z0lVwgl6BJ@UYamHM^^+|^U!tt60+dMI7`@hmD|zuJ+%z~?l}U6O%P@tYA~Pp3n)Gr zYsX%`q#37fXhQVGzWy94Sue5@uw@<&tt7zZvV_$tQCQ?fteT#>0b?P9-Nt@Pd5Y4Qr<@P!M-Txd*P#QPC# z;qmzVJXxhzB(Tjw_qZbKOggwj{Ff+Z)dtp$#4@2ff~yM6Bevm@XWFI2J1!@ck$I?C zQaz5$skU>jJ=MQm0-hLtHSu<9NMDKJ)(UO{7+W0K=vk1^vAprEHI-JoLHx*X0s)Yg z(zn5xd11Q~Fqg@6P1KCO@6^-`?DUsVtktvV@!KR2IwkQ;Ywt!nX*YuowJxn9F?#Gj zj)6#z()w@ob$iG=fa%U1>C8fi$~Ps2d=p+WU3?=*qVStITog{N)d@a`bgW;(hN&m3 z1KNwcrrPfL&7Gi!xHnE0ciLeXO$b?HJV(+09BailQAna{E~zC%tS(e;{QGaOI=oa! 
zdcMBX?GT_u-@k_noPmL>2=yr=G{Jo9nt}L4a<@F6g31Ssz*9-0I%md3(x31&592A4BzzY8WI+Bm_~i~{#j2z2N&-%Y z>bh~Lz{S5mu%r!6Nc(T$n%NI7N&5%jikZcJaL(4qZh@Df{m432C%92KG4wn8FM{uI z&e^D<9N~)TDk8=_Zp17|#7*p7XC%>N(UY(0a8=}exRa1Ty9Dz^BuUwUO0$>vvpB;& zWB; zm2`tAS=d$W4mwG@*Y7g1E5*q1qG43rfs_%GWj6q+g=(sHQuThSnyI~~qcC;$bbsr7 z9Dt-!j`4RmSWF2U7EA}mv#Lz<3+i*~MI>5dS!QG5r{;wMR0uhxex9`_EM8zSQJe|7 z6L9_vzarni;;kUxC$)*+%tJNh6LZ;dJb;`EH-|vfr)A__5Xe)bVNYU#^wI=$`kjHm zAQ7~g}IO0u{CoNmXT(KeQhw7-Sx%t#kNw8bIJ6vjS4{{iRA>&Tv>L?hDpPO{HGppqv|b+y90!5ZGU z__bz|26CMWAXr2l1}Yo7%0|jWRKIJOhR>TzAkc{_E-h^RhZvcsJYj(BnTG^d7D?tM z#+25}F}d|idZoola6-dkh*1TtoY<+*KUwT`h zC*~oTON+r=MnQIN9nMsOo^j!~u-&Ph;h){_MR69g@JOe5LPVv}f6PK6n|E0gg%Q;e zg&7-xD9k@&M5ZvRrfwoqSTTpsO$R+=6To%sD;S*dGgrre53=9$9c zpHoVk+-q{8i={-ulVio4Bk&e!ap=8C=KkDhuC^=yE02#^nrJGqb&evds4K*=3o{85oJ# zl@U@j%&P#BDlJBF#kGyFTVgj)^MJSn+!^#9<(QNNJ$&(q+uPlKJJjDtzy1>ze~JR4 z5!P*!u!Eve%Y$scg&V_$gTyN6G~CDjHZr9~wdjNcY}E27JRiA}joisc7M&P5Wg}x> z^E2!f^LrBRc5r9U@Vw4{WJDh9@4R6Czn+IG_U~Mp0)yxA#9-Ez(~|Fp{U*Ndd>$RC zC!c}Vd%Y|ryj(QbTZ5i_Uqkp>nbfoNWGNekV679c)2XIzmP*QZpzI4r1*^+NfW(U8 z_A None: + super(Quantiser, self).__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.beta = 0.2 + + self.embedding = self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + + """Returns the encoding indices from the input""" + def get_encoding_indices(self, quant_input): + # Flatten + quant_input = quant_input.permute(0, 2, 3, 1) + quant_input = quant_input.reshape((quant_input.size(0), -1, quant_input.size(-1))) + + # Compute pairwise distances + dist = torch.cdist(quant_input, self.embedding.weight[None, :].repeat((quant_input.size(0), 1, 1))) + + # Find index of nearest embedding + encoding_indices = torch.argmin(dist, dim=-1) # in form B, W*H + return encoding_indices + + """Returns the output from the encoding indices""" + def output_from_indices(self, indices, output_shape): + quant_out = torch.index_select(self.embedding.weight, 0, indices.view(-1)) + quant_out = quant_out.reshape(output_shape).permute(0, 3, 1, 2) + return quant_out + + def forward(self, quant_input): + # Finds the encoding indices + encoding_indices = self.get_encoding_indices(quant_input) + # Gets the output based on the encoding indices + quant_out = self.output_from_indices(encoding_indices, quant_input.shape) + + # Losses + commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) + codebook_loss = torch.mean((quant_out - quant_input.detach())**2) + loss = codebook_loss + self.beta*commitment_loss + + # Straight through gradient estimator for backprop + quant_out = quant_input + (quant_out - quant_input).detach() + + # Reshapes encoding indices to 'B, H, W' + encoding_indices = encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) + + return quant_out, loss, encoding_indices + +"""The Decoder Model used in VQ-VAE""" +class Decoder(nn.Module): + def __init__(self, ) -> None: + super(Decoder, self).__init__() + + self.decoder = nn.Sequential( + nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(1), + ) + + def forward(self, x): + out = self.decoder(x) + return out + + + +# ---------------------------------------------------------------------------- +# PixelCNN model +# 
---------------------------------------------------------------------------- + +"""Autoregressive PixelCNN model""" +class PixelCNN(nn.Module): + + def __init__(self, in_channels, hidden_channels, num_embeddings): + super(PixelCNN, self).__init__() + # Equal to the number of embeddings in the VQVAE + self.num_embeddings = num_embeddings + # Initial convolutions skipping the center pixel + self.conv_vstack = VerticalConv(in_channels, hidden_channels, mask_center=True) + self.conv_hstack = HorizontalConv(in_channels, hidden_channels, mask_center=True) + # Convolution block of PixelCNN. Uses dilation instead of downscaling + self.conv_layers = nn.ModuleList([ + GatedMaskedConv(hidden_channels), + GatedMaskedConv(hidden_channels, dilation=2), + GatedMaskedConv(hidden_channels), + GatedMaskedConv(hidden_channels, dilation=4), + GatedMaskedConv(hidden_channels), + GatedMaskedConv(hidden_channels, dilation=2), + GatedMaskedConv(hidden_channels) + ]) + # Output classification convolution (1x1) + # The output channels should be in_channels*number of embeddings to learn continuous space and calc. CrossEntropyLoss + self.conv_out = nn.Conv2d(hidden_channels, in_channels*self.num_embeddings, kernel_size=1, padding=0) + + + def forward(self, x): + # Scale input from 0 to 255 to -1 to 1 + x = (x.float() / 255.0) * 2 - 1 + + # Initial convolutions + v_stack = self.conv_vstack(x) + h_stack = self.conv_hstack(x) + # Gated Convolutions + for layer in self.conv_layers: + v_stack, h_stack = layer(v_stack, h_stack) + # 1x1 classification convolution + # Apply ELU (exponential activation function) before 1x1 convolution for non-linearity on residual connection + out = self.conv_out(F.elu(h_stack)) + + # Output dimensions: [Batch, Classes, Channels, Height, Width] (classes = num_embeddings) + out = out.reshape(out.shape[0], self.num_embeddings, out.shape[1]//256, out.shape[2], out.shape[3]) + return out + + """Indices shape should be in form B C H W + Pixels to fill should be marked with -1""" + @torch.no_grad() + def sample(self, ind_shape, ind): + # Generation loop (iterating through pixels across channels) + for h in range(ind_shape[2]): # Heights + for w in range(ind_shape[3]): # Widths + for c in range(ind_shape[1]): # Channels + # Skip if not to be filled (-1) + if (ind[:,c,h,w] != -1).all().item(): + continue + # Only have to input upper half of ind (rest are masked anyway) + pred = self.forward(ind[:,:,:h+1,:]) + probs = F.softmax(pred[:,:,c,h,w], dim=-1) + ind[:,c,h,w] = torch.multinomial(probs, num_samples=1).squeeze(dim=-1) + return ind + + +"""A general Masked convolution, with a the mask as a parameter.""" +class MaskedConvolution(nn.Module): + + def __init__(self, in_channels, out_channels, mask, dilation=1): + + super(MaskedConvolution, self).__init__() + kernel_size = (mask.shape[0], mask.shape[1]) + padding = ([dilation*(kernel_size[i] - 1) // 2 for i in range(2)]) + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation) + + # Mask as buffer (must be moved with devices) + self.register_buffer('mask', mask[None,None]) + + def forward(self, x): + self.conv.weight.data *= self.mask # Set all following weights to 0 (make sure it is in GPU) + return self.conv(x) + + +class VerticalConv(MaskedConvolution): + # Masks all pixels below + def __init__(self, in_channels, out_channels, kernel_size=3, mask_center=False, dilation=1): + mask = torch.ones(kernel_size, kernel_size) + mask[kernel_size//2+1:,:] = 0 + # For the first convolution, mask center row + if 
mask_center: + mask[kernel_size//2,:] = 0 + + super().__init__(in_channels, out_channels, mask, dilation=dilation) + +class HorizontalConv(MaskedConvolution): + + def __init__(self, in_channels, out_channels, kernel_size=3, mask_center=False, dilation=1): + # Mask out all pixels on the left. (Note that kernel has a size of 1 + # in height because we only look at the pixel in the same row) + mask = torch.ones(1,kernel_size) + mask[0,kernel_size//2+1:] = 0 + + # For first convolution, mask center pixel + if mask_center: + mask[0,kernel_size//2] = 0 + + super().__init__(in_channels, out_channels, mask, dilation=dilation) + +"""Gated Convolutions Model""" +class GatedMaskedConv(nn.Module): + + def __init__(self, in_channels, dilation=1): + + super(GatedMaskedConv, self).__init__() + self.conv_vert = VerticalConv(in_channels, out_channels=2*in_channels, dilation=dilation) + self.conv_horiz = HorizontalConv(in_channels, out_channels=2*in_channels, dilation=dilation) + self.conv_vert_to_horiz = nn.Conv2d(2*in_channels, 2*in_channels, kernel_size=1, padding=0) + self.conv_horiz_1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0) + + def forward(self, v_stack, h_stack): + # Vertical stack (left) + v_stack_feat = self.conv_vert(v_stack) + v_val, v_gate = v_stack_feat.chunk(2, dim=1) + v_stack_out = torch.tanh(v_val) * torch.sigmoid(v_gate) + + # Horizontal stack (right) + h_stack_feat = self.conv_horiz(h_stack) + h_stack_feat = h_stack_feat + self.conv_vert_to_horiz(v_stack_feat) + h_val, h_gate = h_stack_feat.chunk(2, dim=1) + h_stack_feat = torch.tanh(h_val) * torch.sigmoid(h_gate) + h_stack_out = self.conv_horiz_1x1(h_stack_feat) + h_stack_out = h_stack_out + h_stack + + return v_stack_out, h_stack_out \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 000000000..e727e4ac5 --- /dev/null +++ b/train.py @@ -0,0 +1,300 @@ +"""Training of the VQVAE and PixelCNN""" + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as transforms +import numpy as np +from tqdm.auto import tqdm +from PIL import Image +import torch.utils.data +from torchvision import datasets, transforms, utils +import matplotlib.pyplot as plt + +import modules +import dataset + +# Replace with preferred device and local path(s) +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print("Torch version ", torch.__version__) +path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/keras_png_slices_data/keras_png_slices_data/" +vqvae_save_path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/Saved_Models/" +pixelCNN_save_path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/Saved_Models/" + + +# Hyperparameters +batch_size = 128 +vqvae_num_epochs = 60 +vqvae_lr = 1e-3 +cnn_num_epochs = 25 +cnn_lr = 1e-3 + +# Data (If necessary, replace with the local names of the train, validate and test folders) +print("> Loading Data") +processed_data = dataset.DataPreparer(path, "keras_png_slices_train/", "keras_png_slices_validate/", "keras_png_slices_test/", batch_size) +train_dataloader = processed_data.train_dataloader +validate_dataloader = processed_data.validate_dataloader +test_dataloader = processed_data.test_dataloader + +# Models +vqvae_model = modules.VQVAE(num_embeddings=256, embedding_dim=32).to(device) +cnn_model = modules.PixelCNN(in_channels=1, hidden_channels=128, num_embeddings=256).to(device) + +# Optimisers +vqvae_optimiser = 
torch.optim.Adam(vqvae_model.parameters(), vqvae_lr) +cnn_optimiser = torch.optim.Adam(cnn_model.parameters(), cnn_lr) + +# Initialise losses (for graphing) +vqvae_training_loss = [] +vqvae_validation_loss = [] +cnn_training_loss = [] +cnn_validation_loss = [] + + + +# -------------------------------------------------------- +# VQVAE functions +# -------------------------------------------------------- + +def train_vqvae(): + + print("> Training VQVAE") + + for epoch_num, epoch in enumerate(range(vqvae_num_epochs)): + + # Train + vqvae_model.train() + for train_batch in train_dataloader: + images = train_batch.to(device, dtype=torch.float32) + + output, quant_loss, reconstruction_loss, _ = vqvae_model(images) + training_loss = quant_loss + reconstruction_loss # Can be adjusted if necessary + + vqvae_optimiser.zero_grad() # Reset gradients to zero + training_loss.backward() # Calculate grad + vqvae_optimiser.step() # Adjust weights + + with torch.no_grad(): + vqvae_training_loss.append((quant_loss.cpu(), reconstruction_loss.cpu(), training_loss.cpu())) + + # Evaluate + vqvae_model.eval() + for validate_batch in (validate_dataloader): + images = validate_batch.to(device, dtype=torch.float32) + + with torch.no_grad(): + output, quant_loss, reconstruction_loss, _ = vqvae_model(images) + validation_loss = quant_loss + reconstruction_loss + vqvae_validation_loss.append((quant_loss.cpu(), reconstruction_loss.cpu(), validation_loss.cpu())) + + print("Epoch {} of {}. Training Loss: {}, Validation Loss: {}".format(epoch_num+1, vqvae_num_epochs, training_loss, validation_loss)) + + +def plot_vqvae_losses(show_individual_losses=False): + # Losses are in the order (Quant, Reconstruction, Total) + plt.title("VQVAE Losses") + plt.xlabel("Epoch") + plt.ylabel("Loss") + + if show_individual_losses == False: + plt.plot([loss[2] for loss in vqvae_training_loss], color='blue') + plt.plot([loss[2] for loss in vqvae_validation_loss], color='red') + plt.legend(["Training Loss", "Validation Loss"]) + else: + plt.plot([loss[0] for loss in vqvae_training_loss], color='blue', ls='--') + plt.plot([loss[0] for loss in vqvae_validation_loss], color='red', ls='--') + plt.plot([loss[1] for loss in vqvae_training_loss], color='blue') + plt.plot([loss[1] for loss in vqvae_validation_loss], color='red') + plt.legend(["Training Quantisation Loss", "Validation Quantisation Loss", "Training Reconstruction Loss", "Validation Reconstruction Loss"]) + plt.show() + + +"""Function to test the VQVAE. Input the number of samples to show""" +def test_vqvae(num_shown=0): + + print("> Testing VQVAE") + + # Calculate losses + vqvae_model.eval() + test_losses = [] + for test_batch in (test_dataloader): + images = test_batch.to(device, dtype=torch.float32) + with torch.no_grad(): + output, quant_loss, r_loss, _ = vqvae_model(images) + # For averaging loss + test_losses.append((quant_loss + r_loss).cpu()) + + print("Average loss during testing is: ", np.mean(np.array(test_losses))) + + # Show N Reconstructed Images. 
+ if (num_shown != 0): + input_imgs = processed_data.test_dataset[0:num_shown] + input_imgs = input_imgs.to(device, dtype=torch.float32) + with torch.no_grad(): + output_imgs, _, _, encoding_indices = vqvae_model(input_imgs) + + fig, ax = plt.subplots(num_shown, 3) + plt.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0) + ax[0, 0].set_title("Input Image") + ax[0, 1].set_title("CodeBook Indices") + ax[0, 2].set_title("Reconstructed Image") + for i in range(num_shown): + for j in range(3): + ax[i, j].axis('off') + ax[i, 0].imshow(input_imgs[i][0].cpu().numpy(), cmap='gray') + ax[i, 1].imshow(encoding_indices[i].cpu().numpy(), cmap='gray') + ax[i, 2].imshow(output_imgs[i][0].cpu().numpy(), cmap='gray') + + plt.show() + + +# Code +#train_vqvae() +#plot_vqvae_losses() +#test_vqvae(num_shown=3) +#print("> Saving Model") +#torch.save(vqvae_model, vqvae_save_path + "trained_vqvae.pth") + + + +# -------------------------------------------------------- +# PixCNN functions +# -------------------------------------------------------- +vqvae_model = torch.load(vqvae_save_path + "trained_vqvae.pth") +encoder = vqvae_model.__getattr__("encoder") +quantiser = vqvae_model.__getattr__("quantiser") +decoder = vqvae_model.__getattr__("decoder") + + +def train_pixcnn(): + + print("> Training PixelCNN") + + for epoch_num, epoch in enumerate(range(cnn_num_epochs)): + + cnn_model.train() + + for train_batch in train_dataloader: + + # Get the quantised outputs + with torch.no_grad(): + encoder_output = encoder(train_batch.to(device)) + _, _, indices = quantiser(encoder_output) + indices = indices.reshape(indices.size(0), 1, indices.size(1), indices.size(2)).to(device) + + output = cnn_model(indices) + + # Compute loss + nll = F.cross_entropy(output, indices, reduction='none') # Negative log-likelihood + bpd = nll.mean(dim=[1,2,3]) * np.log2(np.exp(1)) # Bits per dimension + training_loss = bpd.mean() + + cnn_optimiser.zero_grad() # Reset gradients to zero for back-prop (not cumulative) + training_loss.backward() # Calculate grad + cnn_optimiser.step() # Adjust weights + + with torch.no_grad(): + cnn_training_loss.append(training_loss.cpu()) + + cnn_model.eval() + + for validate_batch in validate_dataloader: + with torch.no_grad(): + # Get the quantised outputs + encoder_output = encoder(validate_batch.to(device)) + _, _, indices = quantiser(encoder_output) + indices = indices.reshape(indices.size(0), 1, indices.size(1), indices.size(2)).to(device) + output = cnn_model(indices) + + # Compute loss + nll = F.cross_entropy(output, indices, reduction='none') # Negative log-likelihood + bpd = nll.mean(dim=[1,2,3]) * np.log2(np.exp(1)) # Bits per dimension + validation_loss = bpd.mean() + + with torch.no_grad(): + cnn_validation_loss.append(validation_loss.cpu()) + + print("Epoch {} of {}. 
Training Loss: {}, Validation Loss: {}".format(epoch_num+1, cnn_num_epochs, training_loss, validation_loss)) + + +def plot_cnn_loss(): + # Losses are in the order (Quant, Reconstruction, Total) + plt.title("PixelCNN Losses") + plt.xlabel("Epoch") + plt.ylabel("Loss") + plt.plot(cnn_training_loss, color='blue') + plt.plot(cnn_validation_loss, color='red') + plt.legend(["Training Loss", "Validation Loss"]) + plt.show() + + +def test_cnn(shown_imgs=0): + print("> Testing PixelCNN") + + test_loss = [] + + cnn_model.eval() + + with torch.no_grad(): + for test_batch in test_dataloader: + # Get the quantised outputs + encoder_output = encoder(test_batch.to(device)) + _, _, indices = quantiser(encoder_output) + indices = indices.reshape(indices.size(0), 1, indices.size(1), indices.size(2)).to(device) + output = cnn_model(indices) + + # Compute loss + nll = F.cross_entropy(output, indices, reduction='none') # Negative log-likelihood + bpd = nll.mean(dim=[1,2,3]) * np.log2(np.exp(1)) # Bits per dimension + validation_loss = bpd.mean() + test_loss.append(validation_loss.cpu()) + + print("Average loss during testing is: ", np.mean(np.array(test_loss))) + + if shown_imgs != 0: + print(" > Showing Images") + + # Inputs + test_batch = processed_data.test_dataset[0:shown_imgs] + + encoder_output = encoder(test_batch.to(device)) + _, _, indices = quantiser(encoder_output) + + indices_shape = indices.cpu().numpy().shape + + print("Indices shape is: ", indices.cpu().numpy().shape) + indices = indices.reshape((indices_shape[0], 1,indices_shape[1], indices_shape[2])) + print("Indices shape is: ", indices.cpu().numpy().shape) + + # Masked Inputs (only top quarter shown) + masked_indices = 1*indices + masked_indices[:,:,16:,:] = -1 + + gen_indices = cnn_model.sample((shown_imgs, 1, 32, 32), ind=masked_indices*1) + + fig, ax = plt.subplots(shown_imgs, 3) + + for a in ax.flatten(): + a.axis('off') + + ax[0, 0].set_title("Real") + ax[0, 1].set_title("Masked") + ax[0, 2].set_title("Generated") + + for i in range(shown_imgs): + ax[i, 0].imshow(indices[i][0].long().cpu().numpy(), cmap='gray') + ax[i, 1].imshow(masked_indices[i][0].cpu().numpy(), cmap='gray') + ax[i, 2].imshow(gen_indices[i][0].cpu().numpy(), cmap='gray') + plt.show() + + + + plt.imshow(vqvae_model.img_from_indices(gen_indices[0], (1, 32, 32, 32))[0][0].cpu().numpy(), cmap='gray') + plt.show() + +train_pixcnn() +plot_cnn_loss() +test_cnn(shown_imgs=3) +print("Saving pixel cnn") +torch.save(cnn_model, pixelCNN_save_path + "PixelCNN model.pth") \ No newline at end of file From bbe1f49cc969023d2e3940a3b0e3b57c6476d59a Mon Sep 17 00:00:00 2001 From: DruCallaghan Date: Fri, 20 Oct 2023 18:03:06 +1000 Subject: [PATCH 19/26] Final Submission --- dataset.py | 6 +-- modules.py | 50 ++++++++++++++-------- predict | 122 +++++++++++++++++++++++++++++++++++++++++++++++++++++ train.py | 56 +++++++++++++----------- 4 files changed, 190 insertions(+), 44 deletions(-) create mode 100644 predict diff --git a/dataset.py b/dataset.py index 1e9b21b1e..70c2e219a 100644 --- a/dataset.py +++ b/dataset.py @@ -63,8 +63,8 @@ def load_data_from_folder(self, path, name): data.append(image) - #if i == 5: - # return data - #i += 1 + if i == 50: + return data + i += 1 return np.array(data) \ No newline at end of file diff --git a/modules.py b/modules.py index 3a72a0242..7707f4838 100644 --- a/modules.py +++ b/modules.py @@ -15,12 +15,14 @@ class VQVAE(nn.Module): def __init__(self, num_embeddings, embedding_dim): super(VQVAE, self).__init__() + self.epochs_trained = 0 + 
self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim - self.encoder = Encoder() + self.encoder = Encoder(embedding_dim) self.quantiser = Quantiser(num_embeddings=num_embeddings, embedding_dim=embedding_dim) - self.decoder = Decoder() + self.decoder = Decoder(embedding_dim) def forward(self, x): # Input shape is B, C, H, W @@ -36,25 +38,25 @@ def forward(self, x): """Function while allows output to be calculated directly from indices param quant_out_shape is the shape that the quantiser is expected to return""" @torch.no_grad() - def img_from_indices(self, indices, quant_out_shape): - quant_out = self.quantiser.output_from_indices(indices, quant_out_shape) # Output is currently 32*32 img with 32 channels + def img_from_indices(self, indices, quant_out_shape_BHWC): + quant_out = self.quantiser.output_from_indices(indices, quant_out_shape_BHWC) # Output is currently 32*32 img with 32 channels return self.decoder(quant_out) """The Encoder Model used in VQ-VAE""" class Encoder(nn.Module): - def __init__(self, ): + def __init__(self, embedding_dim): super(Encoder, self).__init__() self.encoder = nn.Sequential( nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(16), nn.ReLU(), - nn.Conv2d(16, 16, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(16), - nn.ReLU(), nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(32), nn.ReLU(), + nn.Conv2d(32, embedding_dim, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(embedding_dim), + nn.ReLU(), ) def forward(self, x): @@ -85,17 +87,26 @@ def get_encoding_indices(self, quant_input): encoding_indices = torch.argmin(dist, dim=-1) # in form B, W*H return encoding_indices - """Returns the output from the encoding indices""" - def output_from_indices(self, indices, output_shape): + """Returns the output from the encoding indices without calculating losses etc.""" + def output_from_indices(self, indices, output_shape_BHWC): quant_out = torch.index_select(self.embedding.weight, 0, indices.view(-1)) - quant_out = quant_out.reshape(output_shape).permute(0, 3, 1, 2) + quant_out = quant_out.reshape(output_shape_BHWC).permute(0, 3, 1, 2) return quant_out def forward(self, quant_input): + + B, C, H, W = quant_input.shape + # Finds the encoding indices encoding_indices = self.get_encoding_indices(quant_input) + + quant_input = quant_input.permute(0, 2, 3, 1) + quant_input = quant_input.reshape((quant_input.size(0), -1, quant_input.size(-1))) + # Gets the output based on the encoding indices - quant_out = self.output_from_indices(encoding_indices, quant_input.shape) + quant_out = torch.index_select(self.embedding.weight, 0, encoding_indices.view(-1)) + + quant_input = quant_input.reshape((-1, quant_input.size(-1))) # Losses commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) @@ -104,7 +115,9 @@ def forward(self, quant_input): # Straight through gradient estimator for backprop quant_out = quant_input + (quant_out - quant_input).detach() - + + # Reshape quant_out + quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) # Reshapes encoding indices to 'B, H, W' encoding_indices = encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) @@ -112,14 +125,14 @@ def forward(self, quant_input): """The Decoder Model used in VQ-VAE""" class Decoder(nn.Module): - def __init__(self, ) -> None: + def __init__(self, embedding_dim) -> None: super(Decoder, self).__init__() self.decoder = nn.Sequential( - nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1), - 
nn.BatchNorm2d(16), + nn.ConvTranspose2d(embedding_dim, 32, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(32), nn.ReLU(), - nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1), + nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1), nn.BatchNorm2d(16), nn.ReLU(), nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1), @@ -141,6 +154,9 @@ class PixelCNN(nn.Module): def __init__(self, in_channels, hidden_channels, num_embeddings): super(PixelCNN, self).__init__() + + self.epochs_trained = 0 + # Equal to the number of embeddings in the VQVAE self.num_embeddings = num_embeddings # Initial convolutions skipping the center pixel diff --git a/predict b/predict new file mode 100644 index 000000000..47b9e0c96 --- /dev/null +++ b/predict @@ -0,0 +1,122 @@ +"""Example Predictions""" + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as transforms +import numpy as np +from tqdm.auto import tqdm +from PIL import Image +import torch.utils.data +from torchvision import datasets, transforms, utils +import matplotlib.pyplot as plt + +import modules +import dataset + +# Replace with preferred device and local path(s) +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print("Torch version ", torch.__version__) +path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" +vqvae_save_path = "C:/Users/61423/COMP3710/COMP3710 Report/Models/" +pixelCNN_save_path = "C:/Users/61423/COMP3710/COMP3710 Report/Models/" + +# Load the models +vqvae_model = torch.load(vqvae_save_path + "trained_vqvae.pth", map_location=device) +cnn_model = torch.load(pixelCNN_save_path + "PixelCNN model.pth", map_location=device) + +encoder = vqvae_model.__getattr__("encoder") +quantiser = vqvae_model.__getattr__("quantiser") +decoder = vqvae_model.__getattr__("decoder") + +# Update this parameter if using a newer model +embedding_dim = 32 + +# Load data +processed_data = dataset.DataPreparer(path, "keras_png_slices_train/", "keras_png_slices_validate/", "keras_png_slices_test/", batch_size=128) + +# VQVAE outputs +def show_reconstructed_imgs(num_shown=2): + input_imgs = processed_data.test_dataset[0:num_shown] + input_imgs = input_imgs.to(device, dtype=torch.float32) + with torch.no_grad(): + output_imgs, _, _, encoding_indices = vqvae_model(input_imgs) + + fig, ax = plt.subplots(num_shown, 3) + plt.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0) + ax[0, 0].set_title("Input Image") + ax[0, 1].set_title("CodeBook Indices") + ax[0, 2].set_title("Reconstructed Image") + for i in range(num_shown): + for j in range(3): + ax[i, j].axis('off') + ax[i, 0].imshow(input_imgs[i][0].cpu().numpy(), cmap='gray') + ax[i, 1].imshow(encoding_indices[i].cpu().numpy(), cmap='gray') + ax[i, 2].imshow(output_imgs[i][0].cpu().numpy(), cmap='gray') + + plt.show() + +# PixelCNN Indices Generation +def show_generated_indices(shown_imgs=2): + + print(" > Showing Images") + + # Inputs + test_batch = processed_data.test_dataset[0:shown_imgs] + + encoder_output = encoder(test_batch.to(device)) + _, _, indices = quantiser(encoder_output) + + indices_shape = indices.cpu().numpy().shape + + #print("Indices shape is: ", indices.cpu().numpy().shape) + indices = indices.reshape((indices_shape[0], 1,indices_shape[1], indices_shape[2])) + #print("Indices shape is: ", indices.cpu().numpy().shape) + + # Masked Inputs (only top half shown to model) + masked_indices = 1*indices + masked_indices[:,:,16:,:] = -1 + + gen_indices = 
cnn_model.sample((shown_imgs, 1, 32, 32), ind=masked_indices*1) + + fig, ax = plt.subplots(shown_imgs, 3) + + for a in ax.flatten(): + a.axis('off') + + ax[0, 0].set_title("Real") + ax[0, 1].set_title("Masked") + ax[0, 2].set_title("Generated") + + for i in range(shown_imgs): + ax[i, 0].imshow(indices[i][0].long().cpu().numpy(), cmap='gray') + ax[i, 1].imshow(masked_indices[i][0].cpu().numpy(), cmap='gray') + ax[i, 2].imshow(gen_indices[i][0].cpu().numpy(), cmap='gray') + plt.show() + +def show_generated_output(): + # Inputs + test_batch = processed_data.test_dataset[0:1] + + encoder_output = encoder(test_batch.to(device)) + _, _, indices = quantiser(encoder_output) + + indices_shape = indices.cpu().numpy().shape + + indices = indices.reshape((indices_shape[0], 1,indices_shape[1], indices_shape[2])) + + # Masked Inputs (only top half shown to model) + masked_indices = 1*indices + masked_indices[:,:,:,:] = -1 + + + gen_indices = cnn_model.sample((1, 1, 32, 32), ind=masked_indices*1) + + # Change the last 32 using a new model with embedding_dim > 32 + plt.imshow(vqvae_model.img_from_indices(gen_indices[0], (1, 32, 32, embedding_dim))[0][0].cpu().numpy(), cmap='gray') + plt.show() + +show_reconstructed_imgs() +show_generated_indices() +show_generated_output() diff --git a/train.py b/train.py index e727e4ac5..754c83437 100644 --- a/train.py +++ b/train.py @@ -18,17 +18,17 @@ # Replace with preferred device and local path(s) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print("Torch version ", torch.__version__) -path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/keras_png_slices_data/keras_png_slices_data/" -vqvae_save_path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/Saved_Models/" -pixelCNN_save_path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/Saved_Models/" +path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" +vqvae_save_path = "C:/Users/61423/COMP3710/COMP3710 Report/Models/" +pixelCNN_save_path = "C:/Users/61423/COMP3710/COMP3710 Report/Models/" # Hyperparameters batch_size = 128 -vqvae_num_epochs = 60 -vqvae_lr = 1e-3 -cnn_num_epochs = 25 -cnn_lr = 1e-3 +vqvae_num_epochs = 10 +vqvae_lr = 5e-4 +cnn_num_epochs = 10 +cnn_lr = 5e-4 # Data (If necessary, replace with the local names of the train, validate and test folders) print("> Loading Data") @@ -37,9 +37,12 @@ validate_dataloader = processed_data.validate_dataloader test_dataloader = processed_data.test_dataloader -# Models -vqvae_model = modules.VQVAE(num_embeddings=256, embedding_dim=32).to(device) +# Models (if not already saved) +vqvae_model = modules.VQVAE(num_embeddings=256, embedding_dim=64).to(device) cnn_model = modules.PixelCNN(in_channels=1, hidden_channels=128, num_embeddings=256).to(device) +# Models (if previously saved). 
+#vqvae_model = torch.load(vqvae_save_path + "trained_vqvae_n.pth", map_location=device) +#cnn_model = torch.load(pixelCNN_save_path + "PixelCNN model_n.pth", map_location=device) # Optimisers vqvae_optimiser = torch.optim.Adam(vqvae_model.parameters(), vqvae_lr) @@ -52,7 +55,6 @@ cnn_validation_loss = [] - # -------------------------------------------------------- # VQVAE functions # -------------------------------------------------------- @@ -76,7 +78,7 @@ def train_vqvae(): vqvae_optimiser.step() # Adjust weights with torch.no_grad(): - vqvae_training_loss.append((quant_loss.cpu(), reconstruction_loss.cpu(), training_loss.cpu())) + vqvae_training_loss.append((quant_loss.detach().cpu(), reconstruction_loss.detach().cpu(), training_loss.detach().cpu())) # Evaluate vqvae_model.eval() @@ -86,10 +88,17 @@ def train_vqvae(): with torch.no_grad(): output, quant_loss, reconstruction_loss, _ = vqvae_model(images) validation_loss = quant_loss + reconstruction_loss - vqvae_validation_loss.append((quant_loss.cpu(), reconstruction_loss.cpu(), validation_loss.cpu())) + + with torch.no_grad(): + vqvae_validation_loss.append((quant_loss.detach().cpu(), reconstruction_loss.detach().cpu(), validation_loss.detach().cpu())) + vqvae_model.epochs_trained += 1 print("Epoch {} of {}. Training Loss: {}, Validation Loss: {}".format(epoch_num+1, vqvae_num_epochs, training_loss, validation_loss)) + if (epoch_num+1) % 20 == 0: + print("Saving VQVAE model") + torch.save(vqvae_model, vqvae_save_path + "trained_vqvae_{}.pth".format(vqvae_model.epochs_trained)) + def plot_vqvae_losses(show_individual_losses=False): # Losses are in the order (Quant, Reconstruction, Total) @@ -149,19 +158,16 @@ def test_vqvae(num_shown=0): plt.show() -# Code -#train_vqvae() -#plot_vqvae_losses() -#test_vqvae(num_shown=3) -#print("> Saving Model") -#torch.save(vqvae_model, vqvae_save_path + "trained_vqvae.pth") - +# Uncomment any functions to call +train_vqvae() +plot_vqvae_losses() +test_vqvae(num_shown=3) # -------------------------------------------------------- # PixCNN functions # -------------------------------------------------------- -vqvae_model = torch.load(vqvae_save_path + "trained_vqvae.pth") + encoder = vqvae_model.__getattr__("encoder") quantiser = vqvae_model.__getattr__("quantiser") decoder = vqvae_model.__getattr__("decoder") @@ -215,8 +221,13 @@ def train_pixcnn(): with torch.no_grad(): cnn_validation_loss.append(validation_loss.cpu()) + cnn_model.epochs_trained += 1 print("Epoch {} of {}. 
Training Loss: {}, Validation Loss: {}".format(epoch_num+1, cnn_num_epochs, training_loss, validation_loss)) + if (epoch_num+1) % 20 == 0: + print("Saving Pixel CNN") + torch.save(cnn_model, pixelCNN_save_path + "PixelCNN model_{}.pth".format(cnn_model.epochs_trained)) + def plot_cnn_loss(): # Losses are in the order (Quant, Reconstruction, Total) @@ -288,13 +299,10 @@ def test_cnn(shown_imgs=0): ax[i, 2].imshow(gen_indices[i][0].cpu().numpy(), cmap='gray') plt.show() - - plt.imshow(vqvae_model.img_from_indices(gen_indices[0], (1, 32, 32, 32))[0][0].cpu().numpy(), cmap='gray') plt.show() +# Uncomment any functions to call train_pixcnn() plot_cnn_loss() test_cnn(shown_imgs=3) -print("Saving pixel cnn") -torch.save(cnn_model, pixelCNN_save_path + "PixelCNN model.pth") \ No newline at end of file From 597c6ffda211833215ac6efb12274270a9d2f77c Mon Sep 17 00:00:00 2001 From: DruCallaghan <141201090+DruCallaghan@users.noreply.github.com> Date: Fri, 20 Oct 2023 18:07:37 +1000 Subject: [PATCH 20/26] Update README.md --- README.md | 102 ++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 91 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 4a064f841..f8bb97de1 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,95 @@ -# Pattern Analysis -Pattern Analysis of various datasets by COMP3710 students at the University of Queensland. +Introduction +In this report, a VQVAE will be developed in Pytorch to generate new images using the OASIS dataset (grayscale MRI images of the brain). +A VQVAE (Vector-Quantised Variational Autoencoder) is a neural network architecture, which learns discrete latent space representations of the input. It reconstructs the original input from these representations. Discretising the latent space gives VQVAEs many advantages over the conventional VAEs, making the representations more compact and interpretable. Using an autoregressive neural network to learn the space of discrete representations, new plausible outputs are constructed by sampling from this learned space. +Hyperparameters: +The hyperparameters were tuned throughout the report. The learning rate which yielded the best outcome for the VQVAE was found to be 5e-4, and the batch size, 128. Scheduling was considered however wasn’t necessary for the models. +The number of embeddings was chosen to be 256, and the initial embedding dimension as 32. The embedding dimension, however, was increased, as the low embedding dimension was found to be insufficient in capturing the fine details of the dataset, bottlenecking the model. This caused the loss plateauing at roughly 0.3, with the model not yet achieving the desired sharpness due to this limiting factor. +Results: +The results of the report are shown below, the model generates plausible images, albeit extremely blurry, with much room for improvement. After 60 epochs, the VQVAE was able to produce accurate, albeit blurry, reconstructions. A graph of the losses is shown below. The individual losses (codebook loss, commitment loss and reconstruction loss) were checked numerous times to ensure expected behaviour. +![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/fac1f5a9-42ce-4461-8418-52253bac6bf6) -We create pattern recognition and image processing library for Tensorflow (TF), PyTorch or JAX. 
+This VQVAE produced the outputs shown below:
+![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/2821e6e7-ee87-4b0e-9ae9-955b864f8df3)
+
+
+These results demonstrated that the model was not overfitting and should be trained for more epochs. The loss, however, plateaued at roughly 0.3, which was an indication that the model was not complex enough to capture the detail of the dataset. The PixelCNN was then trained for 25 epochs, with the results as shown below (Top: generated indices, Left: bottom half of the brain generated, Right: entire image generated):
+![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/a4f2a8d4-430a-4f4e-bd5d-5ce9487003c1)
+![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/944c35c0-4cee-4d98-90fa-32a8bcfb5a4b)
+![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/37bd9bd0-6869-4af4-8982-6a1405dc295a)
-
-The library includes the following implemented in Tensorflow:
-* fractals
-* recognition problems
+
+These results demonstrated the model working as intended, even after just 60 epochs and 25 epochs respectively. The loss graph for the PixelCNN is shown below. (Note that the x-axis for the validation loss represents the total number of samples, not the epochs.)
+![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/727e40c9-7079-41d8-8b65-11f1dac486ba)
+
+
+Overfitting was also avoided here, as can be seen from the graph; however, the PixelCNN was also limited by the VQVAE (which at this point had only trained for 60 epochs).
+This led to the number of hidden channels in the encoder and decoder being increased slightly, as well as the embedding dimension (if necessary, the number of embeddings can also be increased).
+The new model performed far better than the old (however, it did not save correctly due to a failure of the technology used) and was able to complete the task. The model provided is able to generate far clearer and more accurate images after approximately 80 and 50 epochs (without overfitting).
+If even more accurate results are required, the overall number of parameters (particularly the embedding dimension and number of embeddings) can be increased further.
+How to Use:
+In order to run this file, the environment must have the following installed:
+- Pytorch
+- TQDM
+- Numpy
+- PIL
+- Matplotlib
+Create a new conda environment with Python 3.8, then run ‘pip install’ followed by the library name for each of the libraries above in the command terminal to install any missing libraries. The OASIS dataset will also need to be downloaded and placed in a directory with three folders – one containing the train set, the validation set, and the test set respectively. This report contains four main files:
+- Modules: Containing the VQVAE and PixelCNN models
+- Dataset: Classes for loading and preprocessing data
+- Train: Functions for training and saving the models, as well as plotting the losses.
+- Predict: Showing example usage of the trained models.
+The ‘train’ and ‘predict’ files in particular have config-specific variables (such as paths, embedding_dim etc.) which must be changed at the top of the file.
+Any changes to the model and data can be made in Modules and Dataset respectively. To train the models, use the train.py file. 
Replace the current path variable with the local path to the OASIS dataset, and the names of the folders in the ‘DataPreparer’ instance, as well as the path where the models and losses will be saved. (Note that the training functions save the models as a new file every 20 epochs). The test functions in train.py also have optional parameters to visualise the data.
+New images can be generated from the predict.py file. The function show_reconstructed_imgs(num_shown=2) shows the VQVAE input images, the codebook indices, and the reconstructions.
+show_generated_indices(shown_imgs=2) shows the codebook indices generated by the PixelCNN. This calls cnn_model.sample(), which takes in codebook indices and generates any which have been replaced with ‘-1’. The line ‘masked_indices[:,:,16:,:] = -1’ can be changed to alter which indices are generated.
+show_generated_output() shows a new generated image. Again, the masked_indices variable can be changed, so that parts of images can be generated based on existing pixels.
+Data Processing:
+In the dataset class, the data is loaded into datasets. The transforms applied to all data are conversion to a tensor and normalisation, based on the mean and standard deviation of the entire dataset. The numbers of images in the train, validation, and test datasets were in the approximate ratio 20:2:1.
+Models Overview
+Note – the VQVAE model described contains 32 embedding dimensions. This was the original model, but this has since been increased.
+VQVAE
+The VQVAE model can be broken up into the encoder, vector quantiser, and decoder. Throughout the process, the data is generally in the shape (Batch, Channel, Height, Width). The batch size used is 128, as this was found to yield the best results overall. A visual depiction and descriptions of the VQ-VAE components are below.
+ ![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/0203148a-b794-4d82-862f-25fcaaa06baa)
+
+(Image from Neural Discrete Representation Learning2, 2018)
+Encoder:
+(B=128, C=1, H=256, W=256) -> (B=128, C=32, H=32, W=32)
+The encoder used is a convolutional neural network which maps the grayscale (single-channel) input images into a continuous latent space with 32 channels. Three convolutions are performed, and after each convolution the data is normalised and a non-linear activation function is applied. Each convolution has a stride of 2, so the height and width in the latent space are reduced by a factor of 8.
+In the current model used, the encoder changes the shape of the data from (128, 1, 256, 256) to (128, 32, 32, 32). To modify any layers of the model, changes can be made to the self.encoder attribute. (Note that the output channels of the encoder and the input channels of the decoder should correspond to the embedding dimension of the vector quantiser).
+Vector Quantiser:
+(B=128, C=32, H=32, W=32) -> (B=128, C=32, H=32, W=32)
+The quantiser contains an embedded space/codebook (which it learns). For each pixel, a vector of the 32 channels associated with it is compared to each of the 32-dimensional embeddings in the embedded space (codebook). The index of the embedding which is closest (based on Euclidean distance) to the given vector is returned. The embedding vector which corresponds to this index is then found, and thus the quantiser maps the 32-channel continuous space to a discrete space.
+This process is broken into a number of functions, increasing the readability of the code and allowing for these functions to be called individually when generating images. 
The functions include:
+- get_encoding_indices, which takes the original input (B, C, H, W) and computes the Euclidean distance to each embedding. It then finds the index of the closest embedding for each position, returning the codebook indices in the shape (B, H*W).
+- output_from_indices, which takes the codebook indices of shape (B, H*W) and returns the corresponding embedding vectors. It also reshapes the output into the correct shape (B, C, H, W) before returning it.
+The forward function calls both of these functions and contains a straight-through gradient estimator (shown by the red arrow in the diagram above). This is needed because the discretisation is non-differentiable, and back-propagation could otherwise not occur. If x is the input to the quantiser and Z(x) the output, then x + (Z(x) - x) still evaluates to Z(x); however, detaching the (Z(x) - x) term ensures that back-propagation follows the dotted line shown below, avoiding the non-differentiable Z(x) block.
+ ![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/cb007ed1-244d-4931-b0e5-ab9a03007541)
+
+The forward function also returns the codebook indices (reshaped back to B, H, W), as well as the associated loss – comprised of the codebook loss and the commitment loss. The codebook loss corresponds to the distance between the latent vectors produced by the encoder and the closest vectors in the codebook, which ensures that the vectors in the embedded space are effective representations. The commitment loss prevents the encoder outputs from oscillating between codebook vectors, encouraging the encoder to commit to an embedding. The vector quantiser has a ‘beta’ attribute (currently 0.2), and the quantiser loss is calculated as quant_loss = codebook_loss + beta*commitment_loss. The ‘beta’ value can be changed to affect the robustness and aggression of the quantiser.
+Decoder:
+(B=128, C=32, H=32, W=32) -> (B=128, C=1, H=256, W=256)
+The decoder takes the output of the vector quantiser, of shape B, C, H, W (currently B, 32, 32, 32), and performs three transposed convolutions to attempt to reconstruct the original image. Each transposed convolution is again followed by normalisation and a non-linear activation function.
+
+The forward call of the VQVAE passes the image through the encoder, vector quantiser, then decoder. It calculates the reconstruction loss, which represents the dissimilarity between the input image and the output image. It returns the output image, vector quantiser loss, reconstruction loss, and the encoding indices (helpful for visualising). The VQVAE model also contains the function img_from_indices. This function is useful as it takes in the codebook indices (B, H, W) and allows for the constructed image to be returned directly, which is used during image generation.
+PixelCNN
+The autoregressive PixelCNN used is based on an implementation for MNIST data by Phillip Lippe1.
+The PixelCNN attempts to learn the space of discrete codebook indices of the VQVAE, such that new plausible combinations can be predicted. This is achieved through a masked convolution, which does not allow the neural network to see the current or future pixels. This is done by creating a mask as shown below and multiplying it with the kernel weights before the convolution. 
+image + + +However, using this kernel directly led to a blind spot in the receptive field, and therefore a gated convolution will be used, with both a vertical stack and horizontal stack, with architecture: +![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/99deeeff-ca6d-48b7-b379-9a27b8f27944) + +(Figure by Aaron van den Oord et al.): +The neural network then learns to predict future indices. A number of classes were used in developing the PixelCNN: +- MaskedConvolution: This is the base class which takes a mask as a parameter and performs a masked convolution. +- VerticalConv: This extends MaskedConvolution and is a vertical stack convolution which considers all pixels above the current pixel. +- HorizontalConv: This extends MaskedConvolution and is a horizontal stack convolution which considers all pixels to the left of the current pixel. It also contains residual blocks, because the output of the horizontal convolution is used in prediction. +- GatedMaskedConv: This performs the horizontal and vertical stack convolutions, and combines them according to the graph above. +The PixelCNN scales all of the indices to a range between 0 and 1. It performs the initial convolutions (which mask the centre pixel), and then a number of gated-masked convolutions. It applies a non-linear activation function, and reshapes the output to (Batch, Classes, Channels, Height, Width). This allows for the Cross-entropy loss between the actual indices and the generated indices to be found. Note that for this model, the loss is given by the bpd (bits per dimension), which is calculated based on the cross-entropy, but is more suitable for the model. +The model also contains a function called ‘sample’. This takes an input of indices, with any indices to generate replaced with ‘-1’. It iterates through the pixels and channels and uses Softmax to predict which indices are likely. For full image generation, the indices can be replaced entirely with ‘-1’s. + +References +1. Lippe, P. (no date) Tutorial 12: Autoregressive Image modelling, Tutorial 12: Autoregressive Image Modelling - UvA DL Notebooks v1.2 documentation. Available at: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial12/Autoregressive_Image_Modeling.html (Accessed: 09 October 2023) +2. Aaron van den Oord, Oriol Vinyals, Koray Kavukcuoglu (30/05/2018), Neural Discrete Representation Learning. -In the recognition folder, you will find many recognition problems solved including: -* OASIS brain segmentation -* Classification -etc. From 1d1d100cc75563d2cd99459d253936c96c23d230 Mon Sep 17 00:00:00 2001 From: DruCallaghan <141201090+DruCallaghan@users.noreply.github.com> Date: Fri, 20 Oct 2023 18:09:00 +1000 Subject: [PATCH 21/26] Update README.md --- README.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index f8bb97de1..2e663e481 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ -Introduction +**Introduction:** In this report, a VQVAE will be developed in Pytorch to generate new images using the OASIS dataset (grayscale MRI images of the brain). A VQVAE (Vector-Quantised Variational Autoencoder) is a neural network architecture, which learns discrete latent space representations of the input. It reconstructs the original input from these representations. Discretising the latent space gives VQVAEs many advantages over the conventional VAEs, making the representations more compact and interpretable. 
Using an autoregressive neural network to learn the space of discrete representations, new plausible outputs are constructed by sampling from this learned space. -Hyperparameters: +**Hyperparameters:** The hyperparameters were tuned throughout the report. The learning rate which yielded the best outcome for the VQVAE was found to be 5e-4, and the batch size, 128. Scheduling was considered however wasn’t necessary for the models. The number of embeddings was chosen to be 256, and the initial embedding dimension as 32. The embedding dimension, however, was increased, as the low embedding dimension was found to be insufficient in capturing the fine details of the dataset, bottlenecking the model. This caused the loss plateauing at roughly 0.3, with the model not yet achieving the desired sharpness due to this limiting factor. -Results: +**Results:** The results of the report are shown below, the model generates plausible images, albeit extremely blurry, with much room for improvement. After 60 epochs, the VQVAE was able to produce accurate, albeit blurry, reconstructions. A graph of the losses is shown below. The individual losses (codebook loss, commitment loss and reconstruction loss) were checked numerous times to ensure expected behaviour. ![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/fac1f5a9-42ce-4461-8418-52253bac6bf6) @@ -26,7 +26,8 @@ Overfitting was also avoided here, as can be seen from the graph, however the Pi This led to the number of hidden channels in the encoder and decoder being increased slightly, as well as the embedding dimensions (if necessary, the number of embeddings can also be increased). The new model performed far better than the old, (however did not save correctly due to failure of the technology used), and was able to complete the task. The model provided is able generate far clearer and more accurate image after approximately 80 and 50 epochs (without overfitting). If even more accurate results are required, the overall number of parameters (particularly embedding dimension, number of embeddings), can be increased further. -How to Use: + +**How to Use:** In order to run this file, the environment must have to following installed: - Pytorch - TQDM @@ -43,9 +44,9 @@ Any changes to the model and data can be made in Modules and Dataset respectivel New images can be generated from the predict.py file. The function show_reconstructed_imgs(num_shown=2) shows the VQVAE input images, the codebook indices, and the reconstructions. show_generated_indices(shown_imgs=2) shows the codebook indices generated by the PixelCNN. This calls cnn_model.sample() which takes in codebook indices, and generates any which have been replaced with ‘-1’. The line ‘masked_indices[:,:,16:,:] = -1’ can be changed to alter which indices are generated. show_generated_output() shows a new generated image. Again, the masked_indices variable can be changed, so that parts of images can be generated based on existing pixels. -Data Processing: +**Data Processing:** In the dataset class, the data is loaded into datasets. The transforms applied to all data is conversion to a tensor and normalisation, based on the mean and standard deviation of the entire dataset. The number of images in the train, validation, and test dataset was of the approximate ratio 20:2:1. -Models Overview +**Models Overview:** Note – the VQVAE model described contains 32 embedding dimensions. This was the original model, but this has since been increased. 
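As a rough illustration of the quantisation step described in the VQVAE section below, here is a minimal, self-contained sketch of the nearest-codebook lookup and the straight-through estimator. It mirrors the logic described for the repository's Quantiser class, but it is not the repository code itself: the batch size and spatial sizes are made up, while the codebook size of 256 and embedding dimension of 32 follow the README description.

```python
import torch
import torch.nn as nn

# Toy sizes for illustration only: 4 images, 32-channel latents on a 32x32 grid,
# and a codebook of 256 embeddings of dimension 32.
B, C, H, W = 4, 32, 32, 32
num_embeddings = 256
codebook = nn.Embedding(num_embeddings, C)

latents = torch.randn(B, C, H, W, requires_grad=True)   # stand-in for the encoder output
flat = latents.permute(0, 2, 3, 1).reshape(-1, C)        # (B*H*W, C): one latent vector per pixel

# Nearest codebook entry by Euclidean distance
dist = torch.cdist(flat, codebook.weight)                # (B*H*W, num_embeddings)
indices = dist.argmin(dim=-1)                            # codebook index per pixel
quantised = codebook(indices)                            # (B*H*W, C): selected embeddings

# Straight-through estimator: the forward value is the quantised vector,
# but gradients flow back to the encoder output unchanged
quantised_st = flat + (quantised - flat).detach()

print(indices.reshape(B, H, W).shape)                                # torch.Size([4, 32, 32])
print(quantised_st.reshape(B, H, W, C).permute(0, 3, 1, 2).shape)    # torch.Size([4, 32, 32, 32])
```

In the full model, the commitment and codebook losses described below are computed between the flattened encoder output and the selected embeddings before the straight-through step is applied.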
VQVAE The VQVAE model can be broken up into the encoder, vector quantiser, and decoder. Throughout the process, the data is generally in the shape (Batch, Channel, Height, Width). The batch size used is 128, as this was found to yield the best results overall. A visual depiction and descriptions of the VQ-VAE components are below. From 8bc05cb96b29b868e78c5eb43a42fb300ea734d6 Mon Sep 17 00:00:00 2001 From: DruCallaghan <141201090+DruCallaghan@users.noreply.github.com> Date: Fri, 20 Oct 2023 18:09:41 +1000 Subject: [PATCH 22/26] Update README.md --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 2e663e481..246cd87d5 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,13 @@ **Introduction:** + In this report, a VQVAE will be developed in Pytorch to generate new images using the OASIS dataset (grayscale MRI images of the brain). A VQVAE (Vector-Quantised Variational Autoencoder) is a neural network architecture, which learns discrete latent space representations of the input. It reconstructs the original input from these representations. Discretising the latent space gives VQVAEs many advantages over the conventional VAEs, making the representations more compact and interpretable. Using an autoregressive neural network to learn the space of discrete representations, new plausible outputs are constructed by sampling from this learned space. **Hyperparameters:** + The hyperparameters were tuned throughout the report. The learning rate which yielded the best outcome for the VQVAE was found to be 5e-4, and the batch size, 128. Scheduling was considered however wasn’t necessary for the models. The number of embeddings was chosen to be 256, and the initial embedding dimension as 32. The embedding dimension, however, was increased, as the low embedding dimension was found to be insufficient in capturing the fine details of the dataset, bottlenecking the model. This caused the loss plateauing at roughly 0.3, with the model not yet achieving the desired sharpness due to this limiting factor. **Results:** + The results of the report are shown below, the model generates plausible images, albeit extremely blurry, with much room for improvement. After 60 epochs, the VQVAE was able to produce accurate, albeit blurry, reconstructions. A graph of the losses is shown below. The individual losses (codebook loss, commitment loss and reconstruction loss) were checked numerous times to ensure expected behaviour. ![image](https://github.com/DruCallaghan/PatternAnalysis-2023/assets/141201090/fac1f5a9-42ce-4461-8418-52253bac6bf6) @@ -28,6 +31,7 @@ The new model performed far better than the old, (however did not save correctly If even more accurate results are required, the overall number of parameters (particularly embedding dimension, number of embeddings), can be increased further. **How to Use:** + In order to run this file, the environment must have to following installed: - Pytorch - TQDM @@ -44,9 +48,12 @@ Any changes to the model and data can be made in Modules and Dataset respectivel New images can be generated from the predict.py file. The function show_reconstructed_imgs(num_shown=2) shows the VQVAE input images, the codebook indices, and the reconstructions. show_generated_indices(shown_imgs=2) shows the codebook indices generated by the PixelCNN. This calls cnn_model.sample() which takes in codebook indices, and generates any which have been replaced with ‘-1’. The line ‘masked_indices[:,:,16:,:] = -1’ can be changed to alter which indices are generated. 
show_generated_output() shows a new generated image. Again, the masked_indices variable can be changed, so that parts of images can be generated based on existing pixels. + **Data Processing:** + In the dataset class, the data is loaded into datasets. The transforms applied to all data is conversion to a tensor and normalisation, based on the mean and standard deviation of the entire dataset. The number of images in the train, validation, and test dataset was of the approximate ratio 20:2:1. **Models Overview:** + Note – the VQVAE model described contains 32 embedding dimensions. This was the original model, but this has since been increased. VQVAE The VQVAE model can be broken up into the encoder, vector quantiser, and decoder. Throughout the process, the data is generally in the shape (Batch, Channel, Height, Width). The batch size used is 128, as this was found to yield the best results overall. A visual depiction and descriptions of the VQ-VAE components are below. From eb795388cd5cb894dc9f5eed8b54ec13281eefe2 Mon Sep 17 00:00:00 2001 From: DruCallaghan <141201090+DruCallaghan@users.noreply.github.com> Date: Fri, 20 Oct 2023 18:10:10 +1000 Subject: [PATCH 23/26] Update README.md --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 246cd87d5..0e4cc24c2 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,12 @@ In this report, a VQVAE will be developed in Pytorch to generate new images using the OASIS dataset (grayscale MRI images of the brain). A VQVAE (Vector-Quantised Variational Autoencoder) is a neural network architecture, which learns discrete latent space representations of the input. It reconstructs the original input from these representations. Discretising the latent space gives VQVAEs many advantages over the conventional VAEs, making the representations more compact and interpretable. Using an autoregressive neural network to learn the space of discrete representations, new plausible outputs are constructed by sampling from this learned space. + **Hyperparameters:** The hyperparameters were tuned throughout the report. The learning rate which yielded the best outcome for the VQVAE was found to be 5e-4, and the batch size, 128. Scheduling was considered however wasn’t necessary for the models. The number of embeddings was chosen to be 256, and the initial embedding dimension as 32. The embedding dimension, however, was increased, as the low embedding dimension was found to be insufficient in capturing the fine details of the dataset, bottlenecking the model. This caused the loss plateauing at roughly 0.3, with the model not yet achieving the desired sharpness due to this limiting factor. + **Results:** The results of the report are shown below, the model generates plausible images, albeit extremely blurry, with much room for improvement. After 60 epochs, the VQVAE was able to produce accurate, albeit blurry, reconstructions. A graph of the losses is shown below. The individual losses (codebook loss, commitment loss and reconstruction loss) were checked numerous times to ensure expected behaviour. @@ -52,6 +54,7 @@ show_generated_output() shows a new generated image. Again, the masked_indices v **Data Processing:** In the dataset class, the data is loaded into datasets. The transforms applied to all data is conversion to a tensor and normalisation, based on the mean and standard deviation of the entire dataset. The number of images in the train, validation, and test dataset was of the approximate ratio 20:2:1. 
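For reference, a hedged sketch of what this preprocessing could look like with torchvision is shown below. The exact implementation lives in the repository's dataset.py; the mean, standard deviation, and file path used here are placeholders for illustration only, not the dataset's real statistics.

```python
import torchvision.transforms as transforms
from PIL import Image

# Placeholder statistics -- in the repository these would be computed over the whole OASIS dataset
DATASET_MEAN = 0.14
DATASET_STD = 0.22

preprocess = transforms.Compose([
    transforms.ToTensor(),                                         # PIL image -> float tensor in [0, 1], shape (1, H, W)
    transforms.Normalize(mean=[DATASET_MEAN], std=[DATASET_STD]),  # standardise with the dataset statistics
])

# Example usage on a single grayscale slice (the path is illustrative only)
slice_img = Image.open("keras_png_slices_test/example_slice.png").convert("L")
tensor_img = preprocess(slice_img)   # shape (1, 256, 256) for the OASIS slices
```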
+ **Models Overview:** Note – the VQVAE model described contains 32 embedding dimensions. This was the original model, but this has since been increased. @@ -97,7 +100,8 @@ The neural network then learns to predict future indices. A number of classes we The PixelCNN scales all of the indices to a range between 0 and 1. It performs the initial convolutions (which mask the centre pixel), and then a number of gated-masked convolutions. It applies a non-linear activation function, and reshapes the output to (Batch, Classes, Channels, Height, Width). This allows for the Cross-entropy loss between the actual indices and the generated indices to be found. Note that for this model, the loss is given by the bpd (bits per dimension), which is calculated based on the cross-entropy, but is more suitable for the model. The model also contains a function called ‘sample’. This takes an input of indices, with any indices to generate replaced with ‘-1’. It iterates through the pixels and channels and uses Softmax to predict which indices are likely. For full image generation, the indices can be replaced entirely with ‘-1’s. -References +**References** + 1. Lippe, P. (no date) Tutorial 12: Autoregressive Image modelling, Tutorial 12: Autoregressive Image Modelling - UvA DL Notebooks v1.2 documentation. Available at: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial12/Autoregressive_Image_Modeling.html (Accessed: 09 October 2023) 2. Aaron van den Oord, Oriol Vinyals, Koray Kavukcuoglu (30/05/2018), Neural Discrete Representation Learning. From 021fbd24ee06842e55568f1737bcd72adaa2309a Mon Sep 17 00:00:00 2001 From: DruCallaghan <141201090+DruCallaghan@users.noreply.github.com> Date: Sat, 21 Oct 2023 17:44:42 +1000 Subject: [PATCH 24/26] Delete Modules directory --- Modules/PixelCNN.py | 136 -------------------------------------------- Modules/VQ_VAE.py | 129 ----------------------------------------- 2 files changed, 265 deletions(-) delete mode 100644 Modules/PixelCNN.py delete mode 100644 Modules/VQ_VAE.py diff --git a/Modules/PixelCNN.py b/Modules/PixelCNN.py deleted file mode 100644 index 9fbf984d2..000000000 --- a/Modules/PixelCNN.py +++ /dev/null @@ -1,136 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.data - - -"""Autoregressive PixelCNN model""" -class PixelCNN(nn.Module): - - def __init__(self, in_channels, hidden_channels, num_embeddings): - super(PixelCNN, self).__init__() - # Equal to the number of embeddings in the VQVAE - self.num_embeddings = num_embeddings - # Initial convolutions skipping the center pixel - self.conv_vstack = VerticalConv(in_channels, hidden_channels, mask_center=True) - self.conv_hstack = HorizontalConv(in_channels, hidden_channels, mask_center=True) - # Convolution block of PixelCNN. Uses dilation instead of downscaling - self.conv_layers = nn.ModuleList([ - GatedMaskedConv(hidden_channels), - GatedMaskedConv(hidden_channels, dilation=2), - GatedMaskedConv(hidden_channels), - GatedMaskedConv(hidden_channels, dilation=4), - GatedMaskedConv(hidden_channels), - GatedMaskedConv(hidden_channels, dilation=2), - GatedMaskedConv(hidden_channels) - ]) - # Output classification convolution (1x1) - # The output channels should be in_channels*number of embeddings to learn continuous space and calc. 
CrossEntropyLoss - self.conv_out = nn.Conv2d(hidden_channels, in_channels*self.num_embeddings, kernel_size=1, padding=0) - - - def forward(self, x): - # Scale input from 0 to 255 to -1 to 1 - x = (x.float() / 255.0) * 2 - 1 - - # Initial convolutions - v_stack = self.conv_vstack(x) - h_stack = self.conv_hstack(x) - # Gated Convolutions - for layer in self.conv_layers: - v_stack, h_stack = layer(v_stack, h_stack) - # 1x1 classification convolution - # Apply ELU (exponential activation function) before 1x1 convolution for non-linearity on residual connection - out = self.conv_out(F.elu(h_stack)) - - # Output dimensions: [Batch, Classes, Channels, Height, Width] (classes = num_embeddings) - out = out.reshape(out.shape[0], self.num_embeddings, out.shape[1]//256, out.shape[2], out.shape[3]) - return out - - """Indices shape should be in form B C H W - Pixels to fill should be marked with -1""" - @torch.no_grad() - def sample(self, ind_shape, ind): - # Generation loop (iterating through pixels across channels) - for h in range(ind_shape[2]): # Heights - for w in range(ind_shape[3]): # Widths - for c in range(ind_shape[1]): # Channels - # Skip if not to be filled (-1) - if (ind[:,c,h,w] != -1).all().item(): - continue - # Only have to input upper half of ind (rest are masked anyway) - pred = self.forward(ind[:,:,:h+1,:]) - probs = F.softmax(pred[:,:,c,h,w], dim=-1) - ind[:,c,h,w] = torch.multinomial(probs, num_samples=1).squeeze(dim=-1) - return ind - - -"""A general Masked convolution, with a the mask as a parameter.""" -class MaskedConvolution(nn.Module): - - def __init__(self, in_channels, out_channels, mask, dilation=1): - - super(MaskedConvolution, self).__init__() - kernel_size = (mask.shape[0], mask.shape[1]) - padding = ([dilation*(kernel_size[i] - 1) // 2 for i in range(2)]) - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation) - - # Mask as buffer (must be moved with devices) - self.register_buffer('mask', mask[None,None]) - - def forward(self, x): - self.conv.weight.data *= self.mask # Set all following weights to 0 (make sure it is in GPU) - return self.conv(x) - - -class VerticalConv(MaskedConvolution): - # Masks all pixels below - def __init__(self, in_channels, out_channels, kernel_size=3, mask_center=False, dilation=1): - mask = torch.ones(kernel_size, kernel_size) - mask[kernel_size//2+1:,:] = 0 - # For the first convolution, mask center row - if mask_center: - mask[kernel_size//2,:] = 0 - - super().__init__(in_channels, out_channels, mask, dilation=dilation) - -class HorizontalConv(MaskedConvolution): - - def __init__(self, in_channels, out_channels, kernel_size=3, mask_center=False, dilation=1): - # Mask out all pixels on the left. 
(Note that kernel has a size of 1 - # in height because we only look at the pixel in the same row) - mask = torch.ones(1,kernel_size) - mask[0,kernel_size//2+1:] = 0 - - # For first convolution, mask center pixel - if mask_center: - mask[0,kernel_size//2] = 0 - - super().__init__(in_channels, out_channels, mask, dilation=dilation) - -"""Gated Convolutions Model""" -class GatedMaskedConv(nn.Module): - - def __init__(self, in_channels, dilation=1): - - super(GatedMaskedConv, self).__init__() - self.conv_vert = VerticalConv(in_channels, out_channels=2*in_channels, dilation=dilation) - self.conv_horiz = HorizontalConv(in_channels, out_channels=2*in_channels, dilation=dilation) - self.conv_vert_to_horiz = nn.Conv2d(2*in_channels, 2*in_channels, kernel_size=1, padding=0) - self.conv_horiz_1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0) - - def forward(self, v_stack, h_stack): - # Vertical stack (left) - v_stack_feat = self.conv_vert(v_stack) - v_val, v_gate = v_stack_feat.chunk(2, dim=1) - v_stack_out = torch.tanh(v_val) * torch.sigmoid(v_gate) - - # Horizontal stack (right) - h_stack_feat = self.conv_horiz(h_stack) - h_stack_feat = h_stack_feat + self.conv_vert_to_horiz(v_stack_feat) - h_val, h_gate = h_stack_feat.chunk(2, dim=1) - h_stack_feat = torch.tanh(h_val) * torch.sigmoid(h_gate) - h_stack_out = self.conv_horiz_1x1(h_stack_feat) - h_stack_out = h_stack_out + h_stack - - return v_stack_out, h_stack_out \ No newline at end of file diff --git a/Modules/VQ_VAE.py b/Modules/VQ_VAE.py deleted file mode 100644 index a0a1ca3fc..000000000 --- a/Modules/VQ_VAE.py +++ /dev/null @@ -1,129 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.data - -# -------------------------------- -# VQVAE MODEL - -"""The VQ-VAE Model""" -class VQVAE(nn.Module): - - def __init__(self, num_embeddings, embedding_dim): - super(VQVAE, self).__init__() - - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - - self.encoder = Encoder() - self.quantiser = Quantiser(num_embeddings=num_embeddings, embedding_dim=embedding_dim) - self.decoder = Decoder() - - def forward(self, x): - # Input shape is B, C, H, W - quant_input = self.encoder(x) - quant_out, quant_loss, encoding_indices = self.quantiser(quant_input) - output = self.decoder(quant_out) - - # Reconstruction Loss, and find the total loss - reconstruction_loss = F.mse_loss(x, output) - total_loss = quant_loss + reconstruction_loss - - return output, total_loss, encoding_indices - - """Function while allows output to be calculated directly from indices - param quant_out_shape is the shape that the quantiser is expected to return""" - @torch.no_grad() - def img_from_indices(self, indices, quant_out_shape): - quant_out = self.quantiser.output_from_indices(indices, quant_out_shape) # Output is currently 32*32 img with 32 channels - return self.decoder(quant_out) - -"""The Encoder Model used in VQ-VAE""" -class Encoder(nn.Module): - def __init__(self, ): - super(Encoder, self).__init__() - - self.encoder = nn.Sequential( - nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.Conv2d(16, 16, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(32), - nn.ReLU(), - ) - - def forward(self, x): - out = self.encoder(x) - return out - -"""The VectorQuantiser Model used in VQ-VAE""" -class Quantiser(nn.Module): - def __init__(self, num_embeddings, embedding_dim) 
-> None: - super(Quantiser, self).__init__() - - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.beta = 0.2 - - self.embedding = self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) - - """Returns the encoding indices from the input""" - def get_encoding_indices(self, quant_input): - # Flatten - quant_input = quant_input.permute(0, 2, 3, 1) - quant_input = quant_input.reshape((quant_input.size(0), -1, quant_input.size(-1))) - - # Compute pairwise distances - dist = torch.cdist(quant_input, self.embedding.weight[None, :].repeat((quant_input.size(0), 1, 1))) - - # Find index of nearest embedding - encoding_indices = torch.argmin(dist, dim=-1) # in form B, W*H - return encoding_indices - - """Returns the output from the encoding indices""" - def output_from_indices(self, indices, output_shape): - quant_out = torch.index_select(self.embedding.weight, 0, indices.view(-1)) - quant_out = quant_out.reshape(output_shape).permute(0, 3, 1, 2) - return quant_out - - def forward(self, quant_input): - # Finds the encoding indices - encoding_indices = self.get_encoding_indices(quant_input) - # Gets the output based on the encoding indices - quant_out = self.output_from_indices(encoding_indices, quant_input.shape) - - # Losses - commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) - codebook_loss = torch.mean((quant_out - quant_input.detach())**2) - loss = codebook_loss + self.beta*commitment_loss - - # Straight through gradient estimator for backprop - quant_out = quant_input + (quant_out - quant_input).detach() - - # Reshapes encoding indices to 'B, H, W' - encoding_indices = encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) - - return quant_out, loss, encoding_indices - -"""The Decoder Model used in VQ-VAE""" -class Decoder(nn.Module): - def __init__(self, ) -> None: - super(Decoder, self).__init__() - - self.decoder = nn.Sequential( - nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(1), - ) - - def forward(self, x): - out = self.decoder(x) - return out \ No newline at end of file From dc0aaa3bc3ac25579c668f4e8662c95a3abd6e18 Mon Sep 17 00:00:00 2001 From: DruCallaghan <141201090+DruCallaghan@users.noreply.github.com> Date: Sat, 21 Oct 2023 17:44:56 +1000 Subject: [PATCH 25/26] Delete VQ_VAE_46992925 directory --- VQ_VAE_46992925/DataPrep.py | 241 --------------- VQ_VAE_46992925/VQ_VAE_original | 524 -------------------------------- VQ_VAE_46992925/test.py | 50 --- 3 files changed, 815 deletions(-) delete mode 100644 VQ_VAE_46992925/DataPrep.py delete mode 100644 VQ_VAE_46992925/VQ_VAE_original delete mode 100644 VQ_VAE_46992925/test.py diff --git a/VQ_VAE_46992925/DataPrep.py b/VQ_VAE_46992925/DataPrep.py deleted file mode 100644 index 566bea6a8..000000000 --- a/VQ_VAE_46992925/DataPrep.py +++ /dev/null @@ -1,241 +0,0 @@ -import os -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision.transforms as transforms -import numpy as np -from tqdm import tqdm -from PIL import Image -from torch.utils.data import Dataset, DataLoader -from torchvision import datasets, transforms, utils -import matplotlib.pyplot as plt - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -print("Torch version ", torch.__version__) - - -# 
------------------------------------------------ -# Data Loader - -#path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" -path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/keras_png_slices_data/keras_png_slices_data/" - - -class ImageDataset(Dataset): - def __init__(self, root_dir, transform=None): - self.root_dir = root_dir - self.image_files = [f for f in os.listdir(root_dir) if f.lower().endswith('.png')] - self.transform = transform - - def __len__(self): - return len(self.image_files) - - def __getitem__(self, idx): - image_path = os.path.join(self.root_dir, self.image_files[idx]) - image = Image.open(image_path).convert('L') # Convert to grayscale - if self.transform: - image = self.transform(image) - return image - -print("Loading data") - -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(0.13242, 0.18826) - ]) - -train_data_dir = "keras_png_slices_train/" -test_data_dir = "keras_png_slices_test/" - - -train_dataset = ImageDataset(path+train_data_dir, transform=transform) -test_dataset = ImageDataset(path+test_data_dir, transform=transform) - -# DataLoaders -# B, C, H, W -train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True) -test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True) - -# Debugging -first_batch = 0 -for batch in train_dataloader: - first_batch = batch - break -print("Shape of first batch is: ", first_batch.shape) -print("First batch - Mean: {} Std: {}".format(torch.mean(first_batch), torch.std(first_batch))) -plt.imshow(first_batch[0][0]) -plt.title("First Training image (Normalised)") -plt.gray() -plt.show() - - -print("> Data Loading Finished") - - -# ------------------------------------------------ -# Model - -class VQVAE(nn.Module): - def __init__(self, ): - super(VQVAE, self).__init__() - - self.encoder = nn.Sequential( - nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.Conv2d(16, 16, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.Conv2d(16, 4, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(4), - nn.ReLU(), - ) - - self.pre_quant_conv = nn.Conv2d(4, 2, kernel_size=1) # TODO FC layer?? 
- self.embedding = nn.Embedding(num_embeddings=256, embedding_dim=2) - self.post_quant_conv = nn.Conv2d(2, 4, kernel_size=1) - - # Commitment loss beta - self.beta = 0.2 - self.alpha = 1.0 - - self.decoder = nn.Sequential( - nn.ConvTranspose2d(4, 16, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1), - nn.Sigmoid(), - ) - - def forward(self, x): - # B, C, H, W - encoded_output = self.encoder(x) - quant_input = self.pre_quant_conv(encoded_output) - - # Quantisation - B, C, H, W = quant_input.shape - quant_input = quant_input.permute(0, 2, 3, 1) - quant_input = quant_input.reshape((quant_input.size(0), -1, quant_input.size(-1))) - - # Compute pairwise distances - dist = torch.cdist(quant_input, self.embedding.weight[None, :].repeat((quant_input.size(0), 1, 1))) - - # Find index of nearest embedding - min_encoding_indices = torch.argmin(dist, dim=-1) - - # Select the embedding weights - quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1)) - - quant_input = quant_input.reshape((-1, quant_input.size(-1))) - - # Compute losses - commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) # TODO change to MSE - codebook_loss = torch.mean((quant_out - quant_input.detach())**2) - - - # Straight through gradient estimator - quant_out = quant_input + (quant_out - quant_input).detach() # Detach ~ ignored for back-prop - quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) - min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) - - # Decoding - decoder_input = self.post_quant_conv(quant_out) - output = self.decoder(decoder_input) - - # Reconstruction Loss, and find the total loss - reconstruction_loss = F.mse_loss(x, output) - total_losses = self.alpha*reconstruction_loss + codebook_loss + self.beta*commitment_loss - - # TODO ensure the losses are balanced - #print("The reconstruction loss makes up {}% of the total loss ({}/{})" - # .format(reconstruction_loss*100//(total_losses), int(reconstruction_loss), int(total_losses))) - - return output, total_losses - - -# ------------------------------------------------ -# Training - -########################## TODO THERE IS NO RECONSTRUCTION LOSS!! - -losses = [] # for visualisation - -# Hyperparams -learning_rate = 1.e-3 -num_epochs = 7 - -model = VQVAE().to(device) -print(model) - -optimiser = torch.optim.Adam(model.parameters(), learning_rate) - -for epoch_num, epoch in enumerate(tqdm(range(num_epochs))): - model.train() - for train_batch in tqdm(train_dataloader): - images = train_batch - images = images.to(device, dtype=torch.float32) - - output, total_losses = model(images) - - optimiser.zero_grad() # Reset gradients to zero for back-prop (not cumulative) - total_losses.backward() # Calculate grad - optimiser.step() # Adjust weights - - # Evaluate - model.eval() - - for test_batch in tqdm(test_dataloader): - images = test_batch - - images = images.to(device, dtype=torch.float32) # (Set as float to ensure weights input are the same type) - - with torch.no_grad(): - output, total_losses = model(images) - - - print("Epoch {} of {}. 
Total Loss: {}".format(epoch_num, num_epochs, total_losses)) - - losses.append(total_losses) # To graph losses - - -# ------------------------------------------------- -# Visualise - -# C, H, W -input_img = test_dataset[0][0] - -# Reshape to B, C, H, W for the model -input_img = input_img.reshape(1, 1, input_img.size(-2), input_img.size(-1)) -input_img = input_img.to(device, dtype=torch.float32) - -# DEBUGGING Print the input image shape and show it. -print("Shape of the input img is: ", input_img.shape) -#plt.imshow(input_img[0][0].cpu().numpy()) -#plt.gray() -#plt.show() - - -with torch.no_grad(): # Ensure no gradient calculation - output, _ = model(input_img) # Forward pass through the model - -print("Shape of the output img is: ", output.shape) - -# Display input and output images -plt.figure(figsize=(10, 5)) -plt.subplot(1, 2, 1) -plt.title("Input Image") -plt.imshow(input_img[0][0].cpu().numpy(), cmap='gray') # Assuming single-channel input - -plt.subplot(1, 2, 2) -plt.title("Model Output") -plt.imshow(output[0][0].cpu().numpy(), cmap='gray') # Assuming single-channel output -plt.show() - -plt.plot(losses) -plt.title("Losses") -plt.xlabel("Num Epochs") -plt.ylabel("Loss") -plt.show() \ No newline at end of file diff --git a/VQ_VAE_46992925/VQ_VAE_original b/VQ_VAE_46992925/VQ_VAE_original deleted file mode 100644 index cc8a02a32..000000000 --- a/VQ_VAE_46992925/VQ_VAE_original +++ /dev/null @@ -1,524 +0,0 @@ -''' -VQ-VAE -Model as implemented by -https://www.youtube.com/watch?v=1ZHzAOutcnw -''' -import os -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision.transforms as transforms -import numpy as np -from tqdm.auto import tqdm -from PIL import Image -import torch.utils.data -from torchvision import datasets, transforms, utils -import matplotlib.pyplot as plt - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -print("Torch version ", torch.__version__) - - -# ------------------------------------------------ -# Data Loader - -#path = "C:/Users/61423/COMP3710/data/keras_png_slices_data/" -path = "//puffball.labs.eait.uq.edu.au/s4699292/Documents/2023 Sem2/Comp3710/keras_png_slices_data/keras_png_slices_data/" - -def load_data_from_folder(name): - data = [] - #i = 0 - list_files = [f for f in os.listdir(path+name) if f.lower().endswith('.png')] - - for filename in tqdm(list_files): # tqdm adds loading bar - - image_path = os.path.join(path+name, filename) - image = Image.open(image_path).convert('L') # Convert to grayscale (single channel) - image = np.array(image) - - # Add channel - # C, H, W - image = np.expand_dims(image, axis=0) - - data.append(image) - - #if i == 25: - # return np.array(data) - #i += 1 - - return np.array(data) - -# Loading -# B, C, H, W (Numpy array) -print("> Loading Training data") -train_data = (load_data_from_folder("keras_png_slices_train/")) -print("> Loading Test data") -test_data = (load_data_from_folder("keras_png_slices_test/")) - -print("The shape of the (training) data is: ", train_data.shape) -print("The shape of the (testing) data is: ", test_data.shape) - -# Transforms and tensor -mean = np.mean(train_data) -std = np.std(train_data) - -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=0.13242, std=0.18826) -]) - -train_data = torch.stack([transform(item) for item in train_data]).permute(0, 2, 3, 1) -test_data = torch.stack([transform(item) for item in test_data]).permute(0, 2, 3, 1) - -print("The shape of the (training) data is: ", train_data.shape) 
-print("The shape of the (testing) data is: ", test_data.shape) - -# DataLoaders -# B, C, H, W -train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True) -test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=True) - -#plt.imshow(train_data[0][0]) -#plt.title("First Training image (Normalised)") -#plt.gray() -#plt.show() - -print("> Data Loading Finished") - -# ------------------------------------------------ -# Model - - - -class Encoder(nn.Module): - def __init__(self, ): - super(Encoder, self).__init__() - - self.encoder = nn.Sequential( - nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.Conv2d(16, 16, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(32), - nn.ReLU(), - ) - - def forward(self, x): - out = self.encoder(x) - return out - - -class Quantiser(nn.Module): - def __init__(self, num_embeddings, embedding_dim) -> None: - super(Quantiser, self).__init__() - - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.beta = 0.2 - - self.embedding = self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) - - - def get_encoding_indices(self, quant_input): - # Flatten - quant_input = quant_input.permute(0, 2, 3, 1) - quant_input = quant_input.reshape((quant_input.size(0), -1, quant_input.size(-1))) - - # Compute pairwise distances - dist = torch.cdist(quant_input, self.embedding.weight[None, :].repeat((quant_input.size(0), 1, 1))) - - # Find index of nearest embedding - encoding_indices = torch.argmin(dist, dim=-1) # in form B, W*H - return encoding_indices - - def output_from_indices(self, indices, output_shape): - quant_out = torch.index_select(self.embedding.weight, 0, indices.view(-1)) - quant_out = quant_out.reshape(output_shape).permute(0, 3, 1, 2) - return quant_out - - def forward(self, quant_input): - - # Get the encoding indices - encoding_indices = self.get_encoding_indices(quant_input) - - quant_out = self.output_from_indices(encoding_indices, quant_input.shape) - - #print(quant_out.shape, quant_input.shape) - - # Compute losses - commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) # TODO change to MSE - codebook_loss = torch.mean((quant_out - quant_input.detach())**2) - loss = codebook_loss + self.beta*commitment_loss - - # Straight through gradient estimator - quant_out = quant_input + (quant_out - quant_input).detach() # Detach ~ ignored for back-prop - - # Reshape encoding indices to 'B, H, W' - # TODO CURRENTLY MEANS NOTHING - encoding_indices = encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1))) - - return quant_out, loss, encoding_indices - - -class Decoder(nn.Module): - def __init__(self, ) -> None: - super(Decoder, self).__init__() - - self.decoder = nn.Sequential( - nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1), - nn.BatchNorm2d(1), - ) - - def forward(self, x): - out = self.decoder(x) - return out - - - - -class VQVAE(nn.Module): - def __init__(self, num_embeddings, embedding_dim): - super(VQVAE, self).__init__() - - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - - self.encoder = Encoder() - self.quantiser = 
Quantiser(num_embeddings=num_embeddings, embedding_dim=embedding_dim) - self.decoder = Decoder() - - - def forward(self, x): - # B, C, H, W - quant_input = self.encoder(x) - quant_out, quant_loss, encoding_indices = self.quantiser(quant_input) - output = self.decoder(quant_out) - - # Reconstruction Loss, and find the total loss - reconstruction_loss = F.mse_loss(x, output) - total_loss = quant_loss + reconstruction_loss - - return output, total_loss, encoding_indices - - @torch.no_grad() - def img_from_indices(self, indices, quant_out_shape): - quant_out = self.quantiser.output_from_indices(indices, quant_out_shape) # Output is currently 32*32 img with 32 channels - return self.decoder(quant_out) - -# ------------------------------------------------ -# Training - -losses = [] # for visualisation - -# Hyperparams -learning_rate = 1.e-3 -num_epochs = 30 - -num_embeddings = 256 -embedding_dim = 32 - -model = VQVAE(num_embeddings=num_embeddings, embedding_dim=embedding_dim).to(device) -print(model) - -optimiser = torch.optim.Adam(model.parameters(), learning_rate) - - -for epoch_num, epoch in enumerate(range(num_epochs)): - model.train() - for train_batch in tqdm(train_dataloader): - images = train_batch - images = images.to(device, dtype=torch.float32) - - output, total_losses, _ = model(images) - - optimiser.zero_grad() # Reset gradients to zero for back-prop (not cumulative) - total_losses.backward() # Calculate grad - optimiser.step() # Adjust weights - - - # Evaluate - model.eval() - - for test_batch in (test_dataloader): - images = test_batch - - images = images.to(device, dtype=torch.float32) # (Set as float to ensure weights input are the same type) - - with torch.no_grad(): - output, total_losses, _ = model(images) - - - print("Epoch {} of {}. 
Total Loss: {}".format(epoch_num, num_epochs, total_losses)) - - losses.append(total_losses.cpu()) # To graph losses (TODO still in tensors) - -# ------------------------------------------------- -# Visualise - - -def plot_results(num_images): - - input_imgs = test_data[0:num_images] - input_imgs = input_imgs.to(device, dtype=torch.float32) - - # DEBUGGING - print("Shape of the input img is: ", input_imgs.shape) - - with torch.no_grad(): # Ensure no gradient calculation - output_imgs, _, encoding_indices = model(input_imgs) - - - #Debugging - print("Shape of the output img is: ", output_imgs.shape) - print("Enc indices shape is: ", encoding_indices.shape) - - - fig, ax = plt.subplots(num_images, 3) - plt.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0) - - ax[0, 0].set_title("Inputs") - ax[0, 1].set_title("CodeBook Indices") - ax[0, 2].set_title("Reconstruction") - - for i in range(num_images): - for j in range(3): - ax[i, j].axis('off') - ax[i, 0].imshow(input_imgs[i][0].cpu().numpy(), cmap='gray') - ax[i, 1].imshow(encoding_indices[i].cpu().numpy()) - ax[i, 2].imshow(output_imgs[i][0].cpu().numpy(), cmap='gray') - - plt.show() - - plt.plot(losses) - plt.title("Losses") - plt.xlabel("Num Epochs") - plt.ylabel("Loss") - plt.show() - -plot_results(2) - - - -# ------------- Pixel CNN -# Define the PixelConvLayer class in PyTorch - -class MaskedConvolution(nn.Module): - - def __init__(self, in_channels, out_channels, mask, dilation=1): - - super(MaskedConvolution, self).__init__() - kernel_size = (mask.shape[0], mask.shape[1]) - padding = ([dilation*(kernel_size[i] - 1) // 2 for i in range(2)]) - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation) - - # Mask as buffer (must be moved with devices) - self.register_buffer('mask', mask[None,None]) - - def forward(self, x): - self.conv.weight.data = self.conv.weight.data.to(device) * self.mask.to(device) # Set all following weights to 0 (make sure it is in GPU) - return self.conv(x) - - -class VerticalConv(MaskedConvolution): - # Masks all pixels below - def __init__(self, in_channels, out_channels, kernel_size=3, mask_center=False, dilation=1): - mask = torch.ones(kernel_size, kernel_size) - mask[kernel_size//2+1:,:] = 0 - # For the first convolution, mask center row - if mask_center: - mask[kernel_size//2,:] = 0 - - super().__init__(in_channels, out_channels, mask, dilation=dilation) - -class HorizontalConv(MaskedConvolution): - - def __init__(self, in_channels, out_channels, kernel_size=3, mask_center=False, dilation=1): - # Mask out all pixels on the left. 
(Note that kernel has a size of 1 - # in height because we only look at the pixel in the same row) - mask = torch.ones(1,kernel_size) - mask[0,kernel_size//2+1:] = 0 - - # For first convolution, mask center pixel - if mask_center: - mask[0,kernel_size//2] = 0 - - super().__init__(in_channels, out_channels, mask, dilation=dilation) - - -class GatedMaskedConv(nn.Module): - - def __init__(self, in_channels, dilation=1): - - super(GatedMaskedConv, self).__init__() - self.conv_vert = VerticalConv(in_channels, out_channels=2*in_channels, dilation=dilation) - self.conv_horiz = HorizontalConv(in_channels, out_channels=2*in_channels, dilation=dilation) - self.conv_vert_to_horiz = nn.Conv2d(2*in_channels, 2*in_channels, kernel_size=1, padding=0) - self.conv_horiz_1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0) - - def forward(self, v_stack, h_stack): - # Vertical stack (left) - v_stack_feat = self.conv_vert(v_stack) - v_val, v_gate = v_stack_feat.chunk(2, dim=1) - v_stack_out = torch.tanh(v_val) * torch.sigmoid(v_gate) - - # Horizontal stack (right) - h_stack_feat = self.conv_horiz(h_stack) - h_stack_feat = h_stack_feat + self.conv_vert_to_horiz(v_stack_feat) - h_val, h_gate = h_stack_feat.chunk(2, dim=1) - h_stack_feat = torch.tanh(h_val) * torch.sigmoid(h_gate) - h_stack_out = self.conv_horiz_1x1(h_stack_feat) - h_stack_out = h_stack_out + h_stack - - return v_stack_out, h_stack_out - -class PixelCNN(nn.Module): - - def __init__(self, in_channels, hidden_channels): - super().__init__() - - # Initial convolutions skipping the center pixel - self.conv_vstack = VerticalConv(in_channels, hidden_channels, mask_center=True) - self.conv_hstack = HorizontalConv(in_channels, hidden_channels, mask_center=True) - # Convolution block of PixelCNN. Uses dilation instead of downscaling - self.conv_layers = nn.ModuleList([ - GatedMaskedConv(hidden_channels), - GatedMaskedConv(hidden_channels, dilation=2), - GatedMaskedConv(hidden_channels), - GatedMaskedConv(hidden_channels, dilation=4), - GatedMaskedConv(hidden_channels), - GatedMaskedConv(hidden_channels, dilation=2), - GatedMaskedConv(hidden_channels) - ]) - # Output classification convolution (1x1) - # The output channels should be in_channels*number of embeddings to learn continuous space and calc. 
CrossEntropyLoss - self.conv_out = nn.Conv2d(hidden_channels, in_channels*num_embeddings, kernel_size=1, padding=0) - - - def forward(self, x): - # Scale input from 0 to 255 to -1 to 1 - x = (x.float() / 255.0) * 2 - 1 - - # Initial convolutions - v_stack = self.conv_vstack(x) - h_stack = self.conv_hstack(x) - # Gated Convolutions - for layer in self.conv_layers: - v_stack, h_stack = layer(v_stack, h_stack) - # 1x1 classification convolution - # Apply ELU (exponential activation function) before 1x1 convolution for non-linearity on residual connection - out = self.conv_out(F.elu(h_stack)) - - # Output dimensions: [Batch, Classes, Channels, Height, Width] (classes = num_embeddings) - out = out.reshape(out.shape[0], num_embeddings, out.shape[1]//256, out.shape[2], out.shape[3]) - return out - - """Indices shape should be in form B C H W - Pixels to fill should be marked with -1""" - @torch.no_grad() - def sample(self, ind_shape, ind=None): - # Create tensor of indices (all -1) - if ind is None: - ind = torch.zeros(ind_shape, dtype=torch.long).to(device) - 1 - # Generation loop (iterating through pixels across channels) - for h in range(ind_shape[2]): # Heights - for w in range(ind_shape[3]): # Widths - for c in range(ind_shape[1]): # Channels - # Skip if not to be filled (-1) - if (ind[:,c,h,w] != -1).all().item(): - continue - # Only have to input upper half of ind (rest are masked anyway) - pred = self.forward(ind[:,:,:h+1,:]) - probs = F.softmax(pred[:,:,c,h,w], dim=-1) - ind[:,c,h,w] = torch.multinomial(probs, num_samples=1).squeeze(dim=-1) - return ind - - -# ------------------------------- -# Training PixCNN - -num_epochs = 30 - -cnn_model = PixelCNN(in_channels=1, hidden_channels=128).to(device) -optimiser = torch.optim.Adam(cnn_model.parameters(), learning_rate) - -# For getting codebook indices -encoder = model.__getattr__("encoder") -quantiser = model.__getattr__("quantiser") -decoder = model.__getattr__("decoder") - -for epoch_num, epoch in enumerate(range(num_epochs)): - - cnn_model.train() - - for train_batch in train_dataloader: - - # Get the quantised outputs - with torch.no_grad(): - encoder_output = encoder(train_batch.to(device)) - _, _, indices = quantiser(encoder_output) - indices = indices.reshape(indices.size(0), 1, indices.size(1), indices.size(2)).to(device) - - output = cnn_model(indices) - - # Compute loss - nll = F.cross_entropy(output, indices, reduction='none') # Negative log-likelihood - bpd = nll.mean(dim=[1,2,3]) * np.log2(np.exp(1)) # Bits per dimension - loss = bpd.mean() - - optimiser.zero_grad() # Reset gradients to zero for back-prop (not cumulative) - loss.backward() # Calculate grad - optimiser.step() # Adjust weights - - cnn_model.eval() - - for test_batch in test_dataloader: - # Get the quantised outputs - with torch.no_grad(): - encoder_output = encoder(test_batch.to(device)) - _, _, indices = quantiser(encoder_output) - indices = indices.reshape(indices.size(0), 1, indices.size(1), indices.size(2)).to(device) - - - output = cnn_model(indices) - #print("Indices is shape: ", indices.detach().cpu().numpy().shape) - #print("Output is shape: ", output.detach().cpu().numpy().shape) - - - # Compute loss - nll = F.cross_entropy(output, indices, reduction='none') # Negative log-likelihood - bpd = nll.mean(dim=[1,2,3]) * np.log2(np.exp(1)) # Bits per dimension - loss = bpd.mean() - - print("Epoch {} of {}. 
Total Loss: {}".format(epoch_num, num_epochs, loss)) - -with torch.no_grad(): - # Show one image - print(" > Showing Images") - #print("Real indices shape: ", indices.detach().cpu().numpy().shape) - gen_indices = cnn_model.sample((1, 1, 32, 32)) - #print("Gen indices shape: ", gen_indices.detach().cpu().numpy().shape) - - fig, ax = plt.subplots(2, 2) - - for a in ax.flatten(): - a.axis('off') - - ax[0, 0].set_title("Real Indices") - ax[0, 0].imshow(indices[0][0].long().cpu().numpy(), cmap='gray') - ax[0, 1].set_title("Real Decoded") - ax[0, 1].imshow(model.img_from_indices(indices, quant_out_shape=(26, 32, 32, 32))[0][0].cpu().numpy(), cmap='gray') - - ax[1, 0].set_title("Generated Indices") - ax[1, 0].imshow(gen_indices[0][0].cpu().numpy(), cmap='gray') - ax[1, 1].set_title("Generated Image") - ax[1, 1].imshow(model.img_from_indices(gen_indices, (1, 32, 32, 32))[0][0].cpu().numpy(), cmap='gray') - plt.show() \ No newline at end of file diff --git a/VQ_VAE_46992925/test.py b/VQ_VAE_46992925/test.py deleted file mode 100644 index 1c5328d4f..000000000 --- a/VQ_VAE_46992925/test.py +++ /dev/null @@ -1,50 +0,0 @@ -import os -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision.transforms as transforms -import numpy as np -from tqdm import tqdm -from PIL import Image -import torch.utils.data -from torchvision import datasets, transforms, utils -import matplotlib.pyplot as plt - -conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3) - -weights = conv.weight.data -biases = conv.bias.data - -print("Weights: ", weights) -print("Biases: ", biases.shape) - - -class MaskedConv2d(nn.Conv2d): - - - def __init__(self, num_channels, kernel_size): - - super(MaskedConv2d, self).__init__(num_channels, num_channels, kernel_size=kernel_size, padding=(kernel_size//2)) - - #self.register_buffer('mask', torch.zeros_like(self.weight)) - - k = self.kernel_size[0] - - self.weight.data[:, :, (k//2+1):, :].zero_() - self.weight.data[:, :, k//2, k//2:].zero_() - - - def forward(self, x): - k = self.kernel_size[0] - # Type 'A' mask - self.weight.data[:, :, (k//2+1):, :].zero_() - self.weight.data[:, :, k//2, k//2:].zero_() - - out = super(MaskedConv2d, self).forward(x) - return out - - - -masked_conv = MaskedConv2d(num_channels=1, kernel_size=5) - -print("Masked weights: ", masked_conv.weight.data) \ No newline at end of file From 48472920621c4fed090d6bf471c96a5309bfa398 Mon Sep 17 00:00:00 2001 From: DruCallaghan <141201090+DruCallaghan@users.noreply.github.com> Date: Sat, 21 Oct 2023 17:45:22 +1000 Subject: [PATCH 26/26] Delete __pycache__ directory --- __pycache__/dataset.cpython-39.pyc | Bin 2534 -> 0 bytes __pycache__/modules.cpython-39.pyc | Bin 8167 -> 0 bytes 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 __pycache__/dataset.cpython-39.pyc delete mode 100644 __pycache__/modules.cpython-39.pyc diff --git a/__pycache__/dataset.cpython-39.pyc b/__pycache__/dataset.cpython-39.pyc deleted file mode 100644 index a7aefde712ca7610679370bd9f9a63af85d22ddb..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2534 zcmcImON$#v5bo}IY9#HhZbE+jrUko*GaX!0j=%x~yxAYsqG@!2)l=PFU0q#|uN5|%9)Z@{ z|Ab$x6Y?7ljy3}ZH-V%NT||5m(TF9~r~1o$roTpN`X;T`EZ>GXBXLsKclDf^)Kbs) z=IeEzeL*5CvJXgP?=!yUZ2$MXFxR9;9-x`Q; zPxiuux8gJ$ps>hO*pZ;XA(%)u+fGKD7fg^uyN69z&JqYDfv5e1({MHHUntL>Gd5r| z`jnJbNdmJZQ>&yUo3V*K=MCVUId1~*7WUNnj`V2+mK?aHeTVGR8J)V39XXL3)n=e2 zs4b;DE#(|5FlcNn>;HwDI$k_h0;S zW6=EJr*Hl^`|FK1Q)bG;Oc}C>dK8?oXy|hj=p>MQ6S@H@$%Dp(eosoePk&^)<`#KQ 