diff --git a/README.md b/README.md
index 4a064f841..0bcb39dfc 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,23 @@
 # Pattern Analysis
-Pattern Analysis of various datasets by COMP3710 students at the University of Queensland.
-
-We create pattern recognition and image processing library for Tensorflow (TF), PyTorch or JAX.
-
-This library is created and maintained by The University of Queensland [COMP3710](https://my.uq.edu.au/programs-courses/course.html?course_code=comp3710) students.
-
-The library includes the following implemented in Tensorflow:
-* fractals
-* recognition problems
-
-In the recognition folder, you will find many recognition problems solved including:
-* OASIS brain segmentation
-* Classification
-etc.
+## Detecting the cancer position with a 2D UNet
+UNet is an encoder-decoder architecture used for image problems such as semantic segmentation (this project) and image generation.
+The structure is shown below.
+![Example Image](unet_img.png)
+(Aramendia, I., 2024)
+
+The first half is the encoder: an initial double-convolution block followed by four downsampling stages (max pooling + double convolution) extracts a latent representation from the input images. At the bottleneck, the network holds the most compressed summary of the input features, from which the output is generated. The second half then upsamples, reversing the encoder's operations. Skip connections concatenate the feature maps from each encoder level with the corresponding decoder level, so spatial detail from the original images is retained through the feature extraction.
+
+## Code structure
+train.py builds the UNet and puts it in training mode. Training then proceeds as follows:
+1. Iterate over the train and validation datasets for 32 epochs (NUM_EPOCHS in train.py).
+2. In each epoch, iterate over the training batches produced by DataLoader, updating the model parameters once per batch.
+3. Switch to evaluation mode and record the loss on the validation dataset.
+4. After all epochs, compute the final loss and Dice Similarity Coefficient on the test dataset.
+
+## Optimization
+The BCE (binary cross-entropy) loss is used as the cost of each iteration:
+![Example image](bceloss.png)
+where y is the true label and y_hat is the predicted probability, i.e. BCE = -(1/N) * sum_i [ y_i * log(y_hat_i) + (1 - y_i) * log(1 - y_hat_i) ].
+The parameters are updated with Stochastic Gradient Descent (SGD with momentum and weight decay).
+
+## How to use the code
+1. Run train.py.
+2. Load the test split (or any other dataset from the same image group) with keras_dataloader() in dataset.py and feed it to predict() in test.py, as sketched below.
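+
+A minimal sketch of step 2 (module and function names are this repository's; the 256 argument simply mirrors test.py and is otherwise an assumption):
+```python
+import dataset
+import test  # test.py defines predict()
+
+test_loader, _ = dataset.keras_dataloader(256, '_test/')
+test.predict(test_loader)
+```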
+
+## Reference
+Aramendia, I. (2024). The U-Net: A Complete Guide.
+https://medium.com/@alejandro.itoaramendia/decoding-the-u-net-a-complete-guide-810b1c6d56d8
diff --git a/bceloss.png b/bceloss.png
new file mode 100644
index 000000000..02aa508db
Binary files /dev/null and b/bceloss.png differ
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 000000000..9764a36c8
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,237 @@
+import torch
+import os
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+import matplotlib.pyplot as plt
+import numpy as np
+import nibabel as nib
+from tqdm import tqdm
+from typing import List
+
+KERAS_PATH = '/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices'
+BATCH_SIZE = 12
+
+def to_channels(arr: np.ndarray, dtype=np.uint8) -> np.ndarray:
+    '''
+    One-hot encode a label image into 6 channels, one per class.
+    The output is channels-last, (height, width, n_channels); ToTensor
+    later moves the channel axis first to fit the UNet expectation.
+    '''
+    channels = [0, 1, 2, 3, 4, 5]  # the 6 segmentation classes in this dataset
+    res = np.zeros(arr.shape + (len(channels),), dtype=dtype)
+    for c in channels:
+        res[..., c][arr == c] = 1
+    return res
+
+# load medical image functions
+def load_data_2D(imageNames, normImage=False, categorical=False, dtype=np.float32, getAffines=False, early_stop=False):
+    '''
+    Load medical image data from the given file names into a single array.
+
+    This function pre-allocates a 4D array for conv2d to avoid excessive memory usage.
+
+    normImage : bool (normalise each image to zero mean and unit variance)
+    early_stop : stop loading prematurely, leaving the array mostly empty; for quick loading and testing scripts.
+    '''
+    affines = []
+
+    # Initialize the array based on the first image to get size information
+    num = len(imageNames)
+    example_image = nib.load(imageNames[0]).get_fdata(caching='unchanged')
+    if len(example_image.shape) == 3:
+        example_image = example_image[:, :, 0]  # Remove extra dimension if present
+
+    if categorical:
+        # to_channels returns channels-last: (rows, cols, channels)
+        example_image = to_channels(example_image, dtype=dtype)
+        rows, cols, channels = example_image.shape
+        images = np.zeros((num, rows, cols, channels), dtype=dtype)
+    else:
+        rows, cols = example_image.shape
+        images = np.zeros((num, 1, rows, cols), dtype=dtype)  # Single channel, channels-first
+
+    for i, inName in enumerate(tqdm(imageNames)):
+        niftiImage = nib.load(inName)
+        inImage = niftiImage.get_fdata(caching='unchanged')  # Read from disk only
+        affine = niftiImage.affine
+        if len(inImage.shape) == 3:
+            inImage = inImage[:, :, 0]  # Remove extra dimensions if present
+        inImage = inImage.astype(dtype)
+        if normImage:
+            inImage = (inImage - inImage.mean()) / inImage.std()
+        if categorical:
+            inImage = to_channels(inImage, dtype=dtype)
+            images[i, :, :, :] = inImage  # (rows, cols, channels)
+        else:
+            images[i, 0, :, :] = inImage  # Consistent shape for non-categorical data
+
+        affines.append(affine)
+        if i > 20 and early_stop:
+            break
+
+    if getAffines:
+        return images, affines
+    else:
+        return images  # 4D data: (N, 1, rows, cols) raw, or (N, rows, cols, channels) categorical
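+
+# Usage sketch (hypothetical path lists; not called anywhere): the two loading
+# modes above produce different layouts, which is easy to trip over.
+def _demo_shapes(img_paths: List[str], seg_paths: List[str]):
+    imgs = load_data_2D(img_paths, normImage=True)     # (N, 1, H, W), channels-first
+    masks = load_data_2D(seg_paths, categorical=True)  # (N, H, W, 6), channels-last one-hot
+    print(imgs.shape, masks.shape)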
+def data_generator(image_files, label_files, batch_size):
+    '''
+    Yield shuffled (image, label) batches as tensors. Used in place of
+    DataLoader when images and labels have different shapes.
+
+    Parameters:
+    image_files: loaded image arrays (e.g. the output of load_data_2D)
+    label_files: loaded segmentation arrays
+    batch_size: number of samples per batch
+    '''
+    num_samples = len(image_files)
+    indices = np.arange(num_samples)
+    np.random.shuffle(indices)
+
+    for start_idx in range(0, num_samples, batch_size):
+        end_idx = min(start_idx + batch_size, num_samples)
+        batch_indices = indices[start_idx:end_idx]
+
+        image_batch = [image_files[i] for i in batch_indices]
+        label_batch = [label_files[i] for i in batch_indices]
+
+        image_batch = torch.tensor(np.stack(image_batch), dtype=torch.float32)
+        label_batch = torch.tensor(np.stack(label_batch), dtype=torch.float32)
+
+        yield image_batch, label_batch
+
+class CustomImageDataset(Dataset):
+    '''
+    Dataset of the images in the specified path together with their segmented versions.
+    '''
+    def __init__(self, img_dir: str, img_type: str, transform=None):
+        """
+        Collect the image and segmentation file paths; the pixels themselves
+        are loaded lazily in __getitem__ via nibabel.
+        Parameters:
+        img_dir: image directory prefix (KERAS_PATH)
+        img_type: '_train/', '_validate/' or '_test/'
+        transform: transformations to apply
+        Returns (per item):
+        3D tensors (image, label) of shape (n_channels, height, width)
+        """
+        self.img_dir = img_dir
+        self.img_type = img_type
+        self.transform = transform
+        # sorted() keeps images and segmentations aligned by filename
+        data_list = sorted(os.listdir(self.img_dir + self.img_type[:-1]))
+        self.img_files = [self.img_dir + self.img_type + data for data in data_list]
+        data_list_seg = sorted(os.listdir(self.img_dir + '_seg' + self.img_type[:-1]))
+        self.seg_img_files = [self.img_dir + '_seg' + self.img_type + data for data in data_list_seg]
+
+    def __len__(self):
+        return len(self.img_files)
+
+    def __getitem__(self, idx):
+        # The segmented image serves as the label.
+        img = load_data_2D([self.img_files[idx]], normImage=True)[0][0]       # (256, 128)
+        label = load_data_2D([self.seg_img_files[idx]], categorical=True)[0]  # (256, 128, 6)
+        if self.transform:
+            img = self.transform(img)      # ToTensor: (256, 128) -> (1, 256, 128)
+            label = self.transform(label)  # ToTensor: (256, 128, 6) -> (6, 256, 128)
+        return img, label
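+
+# Indexing sketch (assumes KERAS_PATH exists on the current machine):
+#   ds = CustomImageDataset(KERAS_PATH, '_train/', transform=transforms.ToTensor())
+#   image, label = ds[0]   # image: (1, 256, 128), label: (6, 256, 128)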
+def keras_dataloader(image_size: int, data_type: str):
+    '''
+    Create a DataLoader over the CustomImageDataset defined above.
+    Run this function three times, once each for the train, validation and
+    test splits. Returns both the DataLoader and the Dataset.
+    (image_size is currently unused; the slices are already 256x128.)
+    '''
+    transformation = transforms.Compose([
+        transforms.ToTensor(),
+    ])
+    dataset = CustomImageDataset(KERAS_PATH, data_type, transform=transformation)
+    # DataLoader requires an object implementing __getitem__ and __len__;
+    # num_workers speeds up loading.
+    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE,
+                            shuffle=True, num_workers=4)
+    return dataloader, dataset
+
+def check_shapes(images, labels):
+    # Flag any sample whose shape differs from the expected fixed size.
+    for i, img in enumerate(images):
+        if img.shape != (1, 256, 128):
+            print(f"Image shape error at index {i}: {img.shape}")
+    for i, lbl in enumerate(labels):
+        if lbl.shape != (6, 256, 128):
+            print(f"Label shape error at index {i}: {lbl.shape}")
+
+def check_shape(images, labels):
+    # Flag any sample whose shape differs from the previous sample's.
+    prev_i = images[0]
+    for i, img in enumerate(images):
+        if prev_i.shape != img.shape:
+            print(f'{i}th im shape {img.shape} from {prev_i.shape}')
+        prev_i = img
+    prev_l = labels[0]
+    for i, lbl in enumerate(labels):
+        if prev_l.shape != lbl.shape:
+            print(f'{i}th lbl shape {lbl.shape} from {prev_l.shape}')
+        prev_l = lbl
+
+if __name__ == '__main__':
+    # Debug: guarded so that importing this module does not trigger data loading.
+    # CustomImageDataset returns (image, label) tuples.
+    a = CustomImageDataset(KERAS_PATH, '_train/', transform=transforms.Compose([transforms.ToTensor()]))
+    print(a[0])
+    #b = DataLoader(a, 12, shuffle=True)
+    #for images, labels in b:
+    #    print(images.shape)
+    #    print(labels.shape)
diff --git a/modules.py b/modules.py
new file mode 100644
index 000000000..ae1e36d11
--- /dev/null
+++ b/modules.py
@@ -0,0 +1,103 @@
+import torch
+from torch import nn
+
+class DoubleConv(nn.Module):
+    """DoubleConv is the basic building block of the encoder and decoder.
+    Two (convolution -> batch norm -> ReLU) units in sequence.
+    mid_channels sets the width of the first convolution's output (needed by
+    the bilinear decoder); it defaults to out_channels.
+    """
+    def __init__(self, in_channels, out_channels, mid_channels=None):
+        super(DoubleConv, self).__init__()
+        if mid_channels is None:
+            mid_channels = out_channels
+        self.double_conv = nn.Sequential(
+            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(mid_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.double_conv(x)
+
+
+class Down(nn.Module):
+    """Downscaling.
+    A max pooling operation (halving height and width) followed by one DoubleConv block.
+    """
+    def __init__(self, in_channels, out_channels):
+        super(Down, self).__init__()
+        self.maxpool_conv = nn.Sequential(
+            nn.MaxPool2d(2),
+            DoubleConv(in_channels, out_channels)
+        )
+
+    def forward(self, x):
+        return self.maxpool_conv(x)
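+
+# Shape sketch for one encoder step, using this project's 256x128 slices
+# (illustrative only):
+#   x  = torch.randn(12, 1, 256, 128)
+#   x1 = DoubleConv(1, 64)(x)   # -> (12, 64, 256, 128): padding keeps spatial size
+#   x2 = Down(64, 128)(x1)      # -> (12, 128, 128, 64): MaxPool2d(2) halves H and W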
+class Up(nn.Module):
+    """Upscaling.
+    Upsample (transposed convolution, or bilinear interpolation), concatenate
+    the feature map from the corresponding "Down" level, then apply DoubleConv.
+    """
+    def __init__(self, in_channels, out_channels, bilinear=True):
+        super(Up, self).__init__()
+
+        if bilinear:
+            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+        else:
+            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
+            self.conv = DoubleConv(in_channels, out_channels)
+
+    def forward(self, x1, x2):
+        x1 = self.up(x1)
+
+        # Input tensor shape: (batch_size, channels, height, width).
+        # Pad x1 so its spatial size matches x2 before concatenation.
+        diffY = x2.size()[2] - x1.size()[2]
+        diffX = x2.size()[3] - x1.size()[3]
+
+        x1 = nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                    diffY // 2, diffY - diffY // 2])
+        x = torch.cat([x2, x1], dim=1)
+        return self.conv(x)
+
+
+class UNet(nn.Module):
+    def __init__(self, num_channels=1, num_classes=6, bilinear=False):
+        '''
+        Construct a typical 2D UNet.
+        Parameters:
+        num_channels: number of colour channels. 1 = greyscale, 3 = RGB
+        num_classes: number of segmentation classes. Each output channel holds
+        the per-pixel probability of the corresponding class.
+        '''
+        super(UNet, self).__init__()
+        self.inc = DoubleConv(num_channels, 64)
+        self.down1 = Down(64, 128)
+        self.down2 = Down(128, 256)
+        self.down3 = Down(256, 512)
+        factor = 2 if bilinear else 1
+
+        self.down4 = Down(512, 1024 // factor)
+        self.up1 = Up(1024, 512 // factor, bilinear)
+        self.up2 = Up(512, 256 // factor, bilinear)
+        self.up3 = Up(256, 128 // factor, bilinear)
+        self.up4 = Up(128, 64, bilinear)
+        self.outc = nn.Conv2d(64, num_classes, 1)
+
+    def forward(self, x):
+        x1 = self.inc(x)
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x4 = self.down3(x3)
+        x5 = self.down4(x4)
+        x = self.up1(x5, x4)
+        x = self.up2(x, x3)
+        x = self.up3(x, x2)
+        x = self.up4(x, x1)
+        x = self.outc(x)
+        return torch.sigmoid(x)  # per-class probability for each pixel (matches BCELoss)
diff --git a/test.py b/test.py
new file mode 100644
index 000000000..b027039bb
--- /dev/null
+++ b/test.py
@@ -0,0 +1,37 @@
+import train
+import modules
+import dataset
+import torch
+from PIL import Image
+import numpy as np
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+MODEL_PATH = '/home/Student/s4722435/miniconda3/envs/new_torch/unet_stuff/model.pth'
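+
+# Forward-pass sanity check (a sketch; run manually if needed):
+#   net = modules.UNet()
+#   out = net(torch.randn(1, 1, 256, 128))
+#   assert out.shape == (1, 6, 256, 128)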
+def predict(loader):
+    '''
+    Produce segmented images for the given dataset (test dataset) and
+    report the Dice Similarity Coefficient.
+    '''
+    model = modules.UNet().to(DEVICE)
+    model.load_state_dict(torch.load(MODEL_PATH, weights_only=True))
+    model.eval()
+    dsc = 0.0
+    with torch.no_grad():
+        for idx, (raw, seg) in enumerate(loader):
+            imgs, labels = raw.to(DEVICE), seg.to(DEVICE)
+            outputs = model(imgs)          # (N, 6, H, W) probabilities
+            preds = outputs.argmax(dim=1)  # (N, H, W) class index per pixel
+            for j in range(preds.size(0)):
+                # Spread the class indices 0-5 over 0-255 so the mask is visible.
+                arr = (preds[j].cpu().numpy() * 51).astype(np.uint8)
+                Image.fromarray(arr).save(f'seg_picture_{idx}_{j}.png')
+            dsc += train.dice_similarity_coeff(outputs, labels)
+    dsc = dsc / len(loader)
+    print('Dice Similarity Coefficient', dsc)
+
+def main():
+    test_loader, _ = dataset.keras_dataloader(256, '_test/')
+    predict(test_loader)
+
+if __name__ == '__main__':
+    main()
diff --git a/train.py b/train.py
new file mode 100644
index 000000000..4c22fe862
--- /dev/null
+++ b/train.py
@@ -0,0 +1,167 @@
+import dataset as dst
+import modules as mm
+import torch
+from torch import nn
+import torch.optim as optim
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+
+IMAGE_SIZE = 128
+NUM_EPOCHS = 32
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+LEARNING_RATE = 0.001
+KERAS_PATH = '/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices'
+
+NUM_CHANNELS = 1
+NUM_CLASSES = 6
+NUM_LEVELS = 3
+
+WEIGHT_DECAY = 0.0005
+MOMENTUM = 0.9
+
+def dice_similarity_coeff(pred, target, smooth=1e-6):
+    '''
+    Dice Similarity Coefficient: 2 * |overlap| / (|pred| + |target|),
+    averaged over batch and classes. Used as the test metric.
+    UNet.forward already applies sigmoid, so pred is a probability map.
+    '''
+    intersection = (pred * target).sum(dim=(2, 3))
+    union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
+
+    dice = (2. * intersection + smooth) / (union + smooth)
+    return dice.mean()
+
+#Visualization of loss
+def plot_line_graph(data, title: str, xlabel: str, ylabel: str, filename: str):
+    """
+    Plot a 2D line graph with xlabel = n_epoch (train and validation) or
+    n_batch (test) and ylabel = the corresponding loss value, then save it
+    with savefig.
+
+    Parameters:
+    data (list or array): values to plot
+    title (str)
+    xlabel (str)
+    ylabel (str)
+    filename (str): name of the png file to save as
+    """
+    x = list(range(len(data)))  # domain 0 to n-1
+    plt.figure()                # fresh figure so successive plots do not overlay
+    plt.plot(x, data, marker='o')
+    plt.title(title)
+    plt.xlabel(xlabel)
+    plt.ylabel(ylabel)
+    plt.savefig(filename)       # save directly; plt.show() would block on a headless cluster
+    plt.close()
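+
+# Worked Dice example (hand-checked): if pred marks 2 pixels, target marks
+# 2 pixels, and they overlap on 1, then DSC = 2*1 / (2+2) = 0.5; identical
+# masks give 1.0 and disjoint masks give ~0.0 (the smooth term avoids 0/0).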
+def criterion(criteria, outputs, labels, n_classes=6):
+    '''
+    Report the Dice Similarity Coefficient per class. The labels are already
+    one-hot (one channel per class, from dst.to_channels), so each channel
+    can be scored independently.
+    '''
+    return [dice_similarity_coeff(outputs[:, c:c + 1], labels[:, c:c + 1])
+            for c in range(n_classes)]
+
+
+def main():
+
+    train_loader, train_set = dst.keras_dataloader(IMAGE_SIZE, '_train/')
+    validation_loader, validation_set = dst.keras_dataloader(IMAGE_SIZE, '_validate/')
+    test_loader, test_set = dst.keras_dataloader(IMAGE_SIZE, '_test/')
+
+    # For visualization
+    train_loss_list = []
+    valid_loss_list = []
+    test_loss_list = []
+
+    # Define the UNet model, loss function and optimizer.
+    model = mm.UNet().to(DEVICE)
+    # BCELoss expects probabilities, which the model's final sigmoid provides;
+    # CrossEntropyLoss would expect raw logits instead.
+    criteria = nn.BCELoss().to(DEVICE)
+    optimizer = optim.SGD(model.parameters(), weight_decay=WEIGHT_DECAY, lr=LEARNING_RATE, momentum=MOMENTUM)
+
+    for epoch in range(NUM_EPOCHS):
+        print(f'Epoch {epoch + 1}/{NUM_EPOCHS}')
+        model.train()
+        train_loss = 0.0
+        val_loss = 0.0
+        for raw, seg in train_loader:
+            # Raw data and segmented data: (12, 1, 256, 128) and (12, 6, 256, 128),
+            # i.e. (batch size, channels, height, width)
+            imgs, labels = raw.to(DEVICE), seg.to(DEVICE)
+            optimizer.zero_grad()
+            outputs = model(imgs)  # (12, 6, 256, 128): probabilities for 6 classes
+            loss = criteria(outputs, labels)
+            loss.backward()
+            optimizer.step()
+            train_loss += loss.item()
+        model.eval()
+        with torch.no_grad():  # no backward pass during validation
+            for raw, seg in validation_loader:
+                imgs, labels = raw.to(DEVICE), seg.to(DEVICE)
+                outputs = model(imgs)
+                loss = criteria(outputs, labels)
+                val_loss += loss.item()
+        # Checkpoint once per epoch, in the format test.py expects (a bare state_dict).
+        torch.save(model.state_dict(),
+                   f='/home/Student/s4722435/miniconda3/envs/new_torch/unet_stuff/model.pth')
+        # Average batch loss for the train and validation sets.
+        train_loss_avg = train_loss / len(train_loader)
+        val_loss_avg = val_loss / len(validation_loader)
+        train_loss_list.append(train_loss_avg)
+        valid_loss_list.append(val_loss_avg)
+
+    test_loss = 0.0
+    dsc = 0.0
+    model.eval()
+    print('Test set')
+    with torch.no_grad():
+        for raw, seg in test_loader:
+            imgs, labels = raw.to(DEVICE), seg.to(DEVICE)
+            outputs = model(imgs)
+            loss = criteria(outputs, labels)
+            test_loss += loss.item()
+            test_loss_list.append(loss.item())
+            dsc += dice_similarity_coeff(outputs, labels)
+    test_loss = test_loss / len(test_loader)
+    dsc = dsc / len(test_loader)
+
+    plot_line_graph(train_loss_list, 'Train loss', xlabel='epoch', ylabel='Train Loss', filename='train_loss.png')
+    plot_line_graph(valid_loss_list, 'Validation loss', xlabel='epoch', ylabel='Validation Loss', filename='val_loss.png')
+    plot_line_graph(test_loss_list, 'Test loss', xlabel='batch', ylabel='Test Loss', filename='test_loss.png')
+    print('Dice Similarity Coefficient', dsc)
+    torch.save(model.state_dict(),
+               f='/home/Student/s4722435/miniconda3/envs/new_torch/unet_stuff/model.pth')
+
+
+if __name__ == "__main__":
+    main()
diff --git a/unet_img.png b/unet_img.png
new file mode 100644
index 000000000..0c0f0c554
Binary files /dev/null and b/unet_img.png differ