diff --git a/recognition/README.md b/recognition/README.md
index 5c646231c..a52c3600a 100644
--- a/recognition/README.md
+++ b/recognition/README.md
@@ -1,10 +1,73 @@
-# Recognition Tasks
-Various recognition tasks solved in deep learning frameworks.
-
-Tasks may include:
-* Image Segmentation
-* Object detection
-* Graph node classification
-* Image super resolution
-* Disease classification
-* Generative modelling with StyleGAN and Stable Diffusion
+# ISIC Lesion Segmentation Algorithm
+
+## Description
+The ISIC Lesion Segmentation Algorithm automatically segments skin lesion boundaries in dermatoscopic images. Early detection of malignant skin lesions is crucial for improving the prognosis of skin cancers such as melanoma. The algorithm analyzes input images with a convolutional neural network (CNN) to identify and segment potential skin lesions, distinguishing them from healthy skin. The Dice similarity coefficient is used to compare the algorithm's output to the ground-truth reference mask; it measures the overlap between the two masks as twice the area of their intersection divided by the sum of their areas.
+
+The model is a modified UNet composed of convolutional layers and skip connections, and it uses deep supervision facilitated by segmentation layers that connect different levels of the network to the final output. The architecture was inspired by the [improved UNet](https://arxiv.org/abs/1802.10508v1) (Figure 1), which proved to be an effective 3D brain tumor segmentation model during the BRATS 2017 challenge. The network is trained on the 2018 ISIC (International Skin Imaging Collaboration) dataset, which contains annotated images of various skin lesions.
+
+![Image of the improved UNet architecture](./UNet_Segmentation_s4745275/images/Figure_1.png)
+Figure 1: Improved UNet architecture. Designed by F. Isensee et al.
+
+## Dependencies
+
+To run the ISIC Lesion Segmentation Algorithm, you'll need the following libraries:
+
+- Python (only verified for 3.7+)
+- numpy: For numerical computations and some tensor operations
+- PyTorch: For building and training the neural network
+- matplotlib: For plotting and visualisation
+- PIL (Pillow): For loading the dataset and visualisation
+
+Any missing dependencies can be installed with `pip install <package>`.
+
+## Reproducibility
+
+To run the algorithm and reproduce the results I've obtained, please be aware of the following considerations:
+
+1. Directory paths for `ISICDataset`: The paths specified when initializing `ISICDataset` may need to be modified to match the directory structure on your machine. Ensure that they point to the location where your dataset resides.
+
+2. Model state dictionary directory: The directory where the model state dictionary is saved/loaded may differ based on your setup. Adjust the path accordingly so the algorithm can load or save the model correctly.
+
+Always ensure that you have the necessary permissions to read/write in the specified directories and that the paths are correctly formatted.
+
+## Usage
+See `predict.py` for a full usage demonstration of the model.
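+Below is a minimal usage sketch (shapes as described in the next two subsections). It assumes the `UNet` class from `modules.py` and the `pre_process_image` pipeline from `dataset.py`, run from the `UNet_Segmentation_s4745275` directory; the weights are randomly initialised unless a saved state dictionary is loaded first.
+
+```python
+import torch
+from PIL import Image
+from modules import UNet
+from dataset import pre_process_image
+
+model = UNet(in_channels=6, num_classes=1)
+model.eval()
+
+# Pre-process the example RGB image into a [1, 6, 256, 256] tensor (3 RGB + 3 HSV channels)
+image = Image.open("images/ISIC_0000000.jpg").convert("RGB")
+batch = pre_process_image(image).unsqueeze(0)
+
+with torch.no_grad():
+    probabilities = model(batch)          # [1, 1, 256, 256], values in [0, 1]
+    mask = (probabilities > 0.5).float()  # binarise into lesion / non-lesion
+```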
+### Input
+torch.Tensor with shape [batch_size, 6, 256, 256]
+- The batch_size denotes the number of input images and is the only dimension that varies
+- 6 channels (3 for RGB and 3 for HSV)
+- The image has spatial dimensions 256x256
+
+### Output
+
+torch.Tensor with shape [batch_size, 1, 256, 256]
+- The batch_size denotes the number of input images and is the only dimension that varies
+- 1 channel containing the per-pixel probability of belonging to a lesion
+- The image has spatial dimensions 256x256
+
+## Results
+After training for 50 epochs, the model attained an average Dice similarity coefficient of 0.7364 on the test set. This performance leaves room for improvement; given more time, I would explore hyperparameter tuning and possibly experiment with alternative optimizers.
+
+![Demonstration of the model's efficacy](./UNet_Segmentation_s4745275/images/Figure_2.png)
+Figure 2: An example output from a random sample. Black indicates non-lesion, white indicates lesion. (25 epochs)
+
+That said, the model does exhibit proficiency in segmenting the image. This is evident in Figure 2, where the output mask closely mirrors the true mask, especially around the edges.
+
+## Pre-processing
+Several transformation pipelines were implemented for both pre-processing and data augmentation; you can find them in `dataset.py`. They convert the provided images and masks into tensors compatible with the model (refer to the Input and Output sections) and normalize the inputs. During training, the `process_and_augment` pipeline was employed, performing random scalings, flips, rotations, and more to improve the model's generalizability.
+
+## Data Splits
+
+The data was partitioned as follows (a short sketch of how these proportions are applied appears at the end of this section):
+
+- Training: 70%
+- Validation: 20%
+- Testing: 10%
+
+With this configuration, a significant majority (70%) of the data is allocated for training. Deep learning models, like the UNet I implemented, require a robust volume of data for effective training. By dedicating a larger segment of the dataset to training, the model encounters a more diverse array of samples, which is essential for discerning and internalizing the underlying patterns. Given the dataset's substantial size (over 2500 samples), allocating 70% to training felt appropriate.
+
+The validation set serves a dual purpose: it allows for ongoing evaluation during training and aids in determining when to cease training (a tactic known as early stopping) to mitigate overfitting. A generous validation set is needed to ensure that the decision to halt training is anchored in a trustworthy performance metric rather than the fluctuations of a smaller subset.
+
+Finally, the test set offers an objective assessment of the model's performance after training. While 10% might seem modest, given the dataset's size it still yields a significant number of samples. Consequently, the test set furnishes a dependable measure of how the model is likely to perform in real-world scenarios.
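+The sketch below shows how these proportions translate into concrete subset sizes with `random_split`; it mirrors the splitting logic in `train.py`, using a dummy dataset as a stand-in for `ISICDataset`:
+
+```python
+import torch
+from torch.utils.data import TensorDataset, random_split
+
+# Illustrative stand-in for ISICDataset (the real dataset holds 2500+ image/mask pairs)
+dataset = TensorDataset(torch.zeros(2500, 1))
+
+total_size = len(dataset)
+train_size = int(total_size * 0.7)
+val_size = int(total_size * 0.2)
+test_size = total_size - train_size - val_size  # remainder (~10%) is kept for testing
+
+train_set, val_set, test_set = random_split(dataset, [train_size, val_size, test_size])
+print(len(train_set), len(val_set), len(test_set))  # 1750 500 250
+```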
diff --git a/recognition/UNet_Segmentation_s4745275/best_model.pth b/recognition/UNet_Segmentation_s4745275/best_model.pth
new file mode 100644
index 000000000..e69de29bb
diff --git a/recognition/UNet_Segmentation_s4745275/dataset.py b/recognition/UNet_Segmentation_s4745275/dataset.py
new file mode 100644
index 000000000..8526847b5
--- /dev/null
+++ b/recognition/UNet_Segmentation_s4745275/dataset.py
@@ -0,0 +1,158 @@
+"""
+Contains the dataset class and transformation pipelines used for loading and preprocessing the data.
+"""
+
+import os
+import torch
+from utils import RandomCenterCrop, RandomRotate90, DictTransform
+from torch.utils.data import Dataset
+from torchvision import transforms
+import torchvision.transforms.functional as TF
+from PIL import Image
+import numpy as np
+
+# These are the default paths for my setup; they may not apply to you. Modify as required
+image_path = "/home/groups/comp3710/ISIC2018/ISIC2018_Task1-2_Training_Input_x2"
+mask_path = "/home/groups/comp3710/ISIC2018/ISIC2018_Task1_Training_GroundTruth_x2"
+inconsistent_path = "/home/Student/s4745275/PatternAnalysis-2023/recognition/UNet_Segmentation_s4745275/inconsistent_ids.txt"
+
+
+def check_consistency(
+    image_dir=image_path, mask_dir=mask_path, inconsistent_path=inconsistent_path
+):
+    image_ids = {
+        img.split(".")[0] for img in os.listdir(image_dir) if img.endswith(".jpg")
+    }
+    mask_ids = {
+        mask.split("_segmentation.")[0]
+        for mask in os.listdir(mask_dir)
+        if mask.endswith("_segmentation.png")
+    }
+
+    # Using set differences to find inconsistencies
+    images_without_masks = image_ids - mask_ids
+    masks_without_images = mask_ids - image_ids
+
+    if images_without_masks or masks_without_images:
+        inconsistent_ids = images_without_masks.union(masks_without_images)
+        # Save the unmatched IDs to a file so they can be excluded later
+        with open(inconsistent_path, "w") as file:
+            for ID in inconsistent_ids:
+                file.write(f"{ID}\n")
+
+        print(f"Detected {len(inconsistent_ids)} inconsistencies; their IDs were saved to {inconsistent_path}")
+
+    # Return the two sets so callers (e.g. handle_inconsistency) can inspect them
+    return images_without_masks, masks_without_images
+
+
+class ISICDataset(Dataset):
+    def __init__(
+        self,
+        transform,
+        image_dir=image_path,
+        mask_dir=mask_path,
+        inconsistent_path=inconsistent_path,
+    ):
+        # Load the inconsistent IDs
+        with open(inconsistent_path, "r") as file:
+            excluded_ids = set(line.strip() for line in file)
+
+        self.image_dir = image_dir
+        self.mask_dir = mask_dir
+        self.image_ids = [
+            img.split(".")[0]
+            for img in os.listdir(image_dir)
+            if img.endswith(".jpg") and img.split(".")[0] not in excluded_ids
+        ]
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.image_ids)
+
+    def handle_inconsistency(self):
+        images_without_masks, masks_without_images = check_consistency(
+            self.image_dir, self.mask_dir
+        )
+        inconsistent_ids = images_without_masks.union(masks_without_images)
+
+        # Save to a file
+        with open(inconsistent_path, "a") as file:  # 'a' mode for appending
+            for ID in inconsistent_ids:
+                file.write(f"{ID}\n")
+
+    def __getitem__(self, idx):
+        img_name = os.path.join(self.image_dir, self.image_ids[idx] + ".jpg")
+        mask_name = os.path.join(
+            self.mask_dir, self.image_ids[idx] + "_segmentation.png"
+        )
+
+        try:
+            with Image.open(img_name) as image, Image.open(mask_name) as mask:
+                image = image.convert("RGB")
+                mask = mask.convert("L")
+                sample = {"image": image, "mask": mask}
+
+                if self.transform:
+                    sample = self.transform(sample)
+
+                # Convert mask to binary 0/1 tensor
+                sample["mask"] = (torch.tensor(np.array(sample["mask"])) > 0.5).float()
+
+                return sample["image"], sample["mask"]
+
+        except FileNotFoundError:
+            
self.handle_inconsistency() + return self.__getitem__(idx) + + +pre_process_image = transforms.Compose( + [ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Lambda( + lambda img_tensor: torch.cat( + [ + img_tensor, + TF.to_tensor(TF.to_pil_image(img_tensor).convert("HSV")), + ], + dim=0, + ) + ), + transforms.Normalize( + [0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5, 0.5, 0.5] + ), + ] +) + +pre_process_mask = transforms.Compose( + [transforms.Resize((256, 256)), transforms.ToTensor()] +) + + +# Transformation pipeline to pre-process and augment the dataset +process_and_augment = transforms.Compose( + [ + RandomRotate90(), + RandomCenterCrop(), + DictTransform(transforms.RandomHorizontalFlip()), + DictTransform(transforms.RandomVerticalFlip()), + DictTransform(transforms.Resize((256, 256))), + DictTransform(transforms.ToTensor()), + DictTransform( + transforms.Lambda( + lambda img_tensor: torch.cat( + [ + img_tensor, + TF.to_tensor(TF.to_pil_image(img_tensor).convert("HSV")), + ], + dim=0, + ) + ), + False, + ), + DictTransform( + transforms.Normalize( + [0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5, 0.5, 0.5] + ), + False, + ), + ] +) diff --git a/recognition/UNet_Segmentation_s4745275/images/Figure_1.png b/recognition/UNet_Segmentation_s4745275/images/Figure_1.png new file mode 100644 index 000000000..3bca31e8a Binary files /dev/null and b/recognition/UNet_Segmentation_s4745275/images/Figure_1.png differ diff --git a/recognition/UNet_Segmentation_s4745275/images/Figure_2.png b/recognition/UNet_Segmentation_s4745275/images/Figure_2.png new file mode 100644 index 000000000..aa0225ef4 Binary files /dev/null and b/recognition/UNet_Segmentation_s4745275/images/Figure_2.png differ diff --git a/recognition/UNet_Segmentation_s4745275/images/ISIC_0000000.jpg b/recognition/UNet_Segmentation_s4745275/images/ISIC_0000000.jpg new file mode 100644 index 000000000..0f8e21eb2 Binary files /dev/null and b/recognition/UNet_Segmentation_s4745275/images/ISIC_0000000.jpg differ diff --git a/recognition/UNet_Segmentation_s4745275/images/ISIC_0000000_segmentation.png b/recognition/UNet_Segmentation_s4745275/images/ISIC_0000000_segmentation.png new file mode 100644 index 000000000..caa4c0a4d Binary files /dev/null and b/recognition/UNet_Segmentation_s4745275/images/ISIC_0000000_segmentation.png differ diff --git a/recognition/UNet_Segmentation_s4745275/images/Training_evolution.png b/recognition/UNet_Segmentation_s4745275/images/Training_evolution.png new file mode 100644 index 000000000..0d34533c2 Binary files /dev/null and b/recognition/UNet_Segmentation_s4745275/images/Training_evolution.png differ diff --git a/recognition/UNet_Segmentation_s4745275/modules.py b/recognition/UNet_Segmentation_s4745275/modules.py new file mode 100644 index 000000000..6c1ff6968 --- /dev/null +++ b/recognition/UNet_Segmentation_s4745275/modules.py @@ -0,0 +1,179 @@ +""" +Contains the source code of the components in my model. Each component is implemented as a class or function. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Context(nn.Module): + """ + Encoder that computes the activations in the context pathway. This class behaves as the 'context module' from the paper. + Each Context module is a pre-activation residual block with two 3x3x3 convolutional layers and a dropout layer (p = 0.3) in between. + Instance normalization and leaky ReLU is used throughout the network. 
+    """
+
+    def __init__(self, in_channels):
+        super(Context, self).__init__()
+        # 3x3 convolutional layer that preserves the input size
+        self.conv = nn.Conv2d(
+            in_channels, in_channels, kernel_size=3, stride=1, padding=1
+        )
+        # Dropout layer with p_drop = 0.3
+        self.dropout = nn.Dropout(p=0.3)
+        # Instance normalization of the input is used
+        self.norm = nn.InstanceNorm2d(in_channels)
+
+    def forward(self, x):
+        # Keep track of the initial input
+        shortcut = x
+        # First convolutional layer (note: self.conv and self.norm are reused below, so weights are shared)
+        x = F.leaky_relu(self.norm(x), negative_slope=1e-2)
+        x = self.conv(x)
+        # Dropout layer
+        x = self.dropout(x)
+        # Second convolution layer
+        x = F.leaky_relu(self.norm(x), negative_slope=1e-2)
+        x = self.conv(x)
+        # Return the residual output
+        return x + shortcut
+
+
+class Upsampling(nn.Module):
+    """
+    Upsampling module used to transfer information from low-resolution feature maps into high-resolution feature maps.
+    We use a simple upscale that repeats each feature pixel twice in each spatial dimension, followed by a 3x3 convolution
+    that halves the number of feature maps. Instance normalization and leaky ReLU are used throughout the network.
+    """
+
+    def __init__(self, in_channels):
+        super(Upsampling, self).__init__()
+        # Upsampling components:
+        self.up_norm = nn.InstanceNorm2d(in_channels)
+        self.up_conv = nn.Conv2d(
+            in_channels, in_channels // 2, kernel_size=3, stride=1, padding=1
+        )
+
+        # Localisation components:
+        self.merged_norm = nn.InstanceNorm2d(in_channels)
+        self.merged_conv = nn.Conv2d(
+            in_channels, in_channels, kernel_size=3, stride=1, padding=1
+        )
+        self.half_norm = nn.InstanceNorm2d(in_channels // 2)
+        self.half_conv = nn.Conv2d(
+            in_channels, in_channels // 2, kernel_size=3, stride=1, padding=1
+        )
+
+    def forward(self, x, skip_features, local: bool = True):
+        """
+        Forward pass of the Upsampling module, with an optional localisation step.
+
+        Parameters:
+        - x (torch.Tensor): Input tensor, representing the activations from a deeper layer.
+        - skip_features (torch.Tensor): Activations from the corresponding encoder layer, for the skip connection.
+        - local (bool, optional): If True, the localisation operations are applied after upsampling. Default is True.
+
+        Returns:
+        - torch.Tensor: Upsampled (and optionally localised) feature map.
+
+        The function first upsamples the input tensor 'x' and then concatenates the result with the 'skip_features'.
+        If 'local' is True, it subsequently applies the localisation steps to refine the feature maps further.
+ """ + # Upsampling Module + upsampled = F.interpolate(x, scale_factor=2, mode="nearest") + upsampled = F.leaky_relu(self.up_norm(upsampled), negative_slope=1e-2) + upsampled = self.up_conv(upsampled) + + # Concatenate upsampled features with context features + merged = torch.cat([upsampled, skip_features], dim=1) + + if local is False: + return merged + + # Localisation Module + localised = F.leaky_relu(self.merged_norm(merged), negative_slope=1e-2) + localised = self.merged_conv(localised) # First convolutional layer + + localised = F.leaky_relu(self.half_norm(localised), negative_slope=1e-2) + localised = self.half_conv(localised) # Second convolutional layer + + return localised + + +class Segmentation(nn.Module): + def __init__(self, in_channels, num_classes): + super(Segmentation, self).__init__() + self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1, stride=1) + + def forward(self, x, other_layer, upscale=True): + x = self.conv(x) + if other_layer is not None: + x += other_layer + if not upscale: + return x + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class UNet(nn.Module): + def __init__(self, in_channels, num_classes=1): + super(UNet, self).__init__() + + # Context modules + self.context1 = Context(16) + self.context2 = Context(32) + self.context3 = Context(64) + self.context4 = Context(128) + self.context5 = Context(256) + + # Upsampling modules + self.up1 = Upsampling(256) + self.up2 = Upsampling(128) + self.up3 = Upsampling(64) + self.up4 = Upsampling(32) + + # Convolutional layer used on input channel + self.input_conv = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1) + + # Convolutional layers that connect context modules + self.conv1 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1) + self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1) + self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) + self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) + + # Convolutional layer for the localisation pathway output + self.end_conv = nn.Conv2d(32, 32, kernel_size=1) + + # Segmentation layers + self.segment1 = Segmentation(64, num_classes) + self.segment2 = Segmentation(32, num_classes) + self.segment3 = Segmentation(32, num_classes) + + def forward(self, x): + # Context Pathway + cntx_1 = self.context1(self.input_conv(x)) + cntx_2 = self.context2(self.conv1(cntx_1)) + cntx_3 = self.context3(self.conv2(cntx_2)) + cntx_4 = self.context4(self.conv3(cntx_3)) + cntx_5 = self.context5(self.conv4(cntx_4)) + + # Localization Pathway + local_1 = self.up1(cntx_5, cntx_4) + local_2 = self.up2(local_1, cntx_3) + local_3 = self.up3(local_2, cntx_2) + local_out = self.end_conv(self.up4(local_3, cntx_1, False)) + + # Deep Supervision + seg_1 = self.segment1(local_2, None) + seg_2 = self.segment2(local_3, seg_1) + seg_3 = self.segment3(local_out, seg_2, upscale=False) + + # Apply sigmoid (paper used softmax, but this is binary) and return + return torch.sigmoid(seg_3) + # if num_classes > 2: return F.softmax(seg_3, dim=1) + +# Lil testing +model = UNet(6, 1) +test = torch.randn(1, 6, 32, 32) +output = model(test) +print(output.size()) \ No newline at end of file diff --git a/recognition/UNet_Segmentation_s4745275/predict.py b/recognition/UNet_Segmentation_s4745275/predict.py new file mode 100644 index 000000000..f2da760b0 --- /dev/null +++ b/recognition/UNet_Segmentation_s4745275/predict.py @@ -0,0 +1,62 @@ +"""Showing an example useage of the trained model using an image in the images folder. 
We also compare the prediction to the true mask"""
+
+import os
+from matplotlib import pyplot
+from modules import UNet
+from dataset import pre_process_image, pre_process_mask
+from PIL import Image
+import torch
+
+# If available, it is preferable to run the model on a GPU device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# An example image and the corresponding mask are located in the images directory
+example_path = "recognition/UNet_Segmentation_s4745275/images/ISIC_0000000.jpg"
+mask_path = (
+    "recognition/UNet_Segmentation_s4745275/images/ISIC_0000000_segmentation.png"
+)
+
+# Trained model weights and parameters are stored in best_model.pth
+model_path = "recognition/UNet_Segmentation_s4745275/best_model.pth"
+if not os.path.exists(model_path):
+    # Execute the train.py file, which will create best_model.pth
+    os.system("python train.py")
+
+# Load an instance of the trained model
+model = UNet(in_channels=6, num_classes=1)
+model.load_state_dict(torch.load(model_path, map_location=device))
+model = model.to(device)
+
+# Load the input image in RGB mode
+image = Image.open(example_path).convert("RGB")
+# Pre-process the input image
+image = (
+    pre_process_image(image).unsqueeze(0).to(device)
+)  # Add a batch dimension and send to device
+
+with torch.no_grad():
+    prediction = model(image)  # This is the prediction of the algorithm
+    prediction = (prediction > 0.5).float()  # Binarize the output
+
+
+# Visual comparison of the predicted segment to the true segment:
+
+# Convert output tensor to numpy array for visualization
+predicted_np = prediction.squeeze().cpu().numpy()
+fig, ax = pyplot.subplots(1, 2, figsize=(10, 5))
+
+# Open and process the correct mask for visualization
+mask = Image.open(mask_path).convert("L")
+mask = pre_process_mask(mask)
+
+# True Mask
+ax[0].imshow(mask.squeeze().numpy(), cmap="gray")
+ax[0].set_title("True Mask")
+
+# Predicted Mask
+ax[1].imshow(predicted_np, cmap="gray")
+ax[1].set_title("Predicted Mask")
+
+pyplot.show()
diff --git a/recognition/UNet_Segmentation_s4745275/slurm.sh b/recognition/UNet_Segmentation_s4745275/slurm.sh
new file mode 100644
index 000000000..08f4d8854
--- /dev/null
+++ b/recognition/UNet_Segmentation_s4745275/slurm.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+#SBATCH --gres=gpu:1
+#SBATCH --partition=vgpu
+#SBATCH --job-name="COMP3710 Report"
+#SBATCH --mail-user=s4745275@student.uq.edu.au
+#SBATCH --mail-type=ALL
+#SBATCH -e test_err.txt
+#SBATCH -o test_out.txt
+
+source /home/Student/s4745275/miniconda/bin/activate /home/Student/s4745275/my_demo_environment
+python train.py
diff --git a/recognition/UNet_Segmentation_s4745275/train.py b/recognition/UNet_Segmentation_s4745275/train.py
new file mode 100644
index 000000000..86f567531
--- /dev/null
+++ b/recognition/UNet_Segmentation_s4745275/train.py
@@ -0,0 +1,162 @@
+"""
+Contains the source code for training, validating, testing and saving my model.
+The model should be imported from “modules.py” and the data loader should be imported from “dataset.py”.
+Make sure to plot the losses and metrics during training +""" + + +from dataset import ( + ISICDataset, + process_and_augment, + check_consistency, +) +import numpy as np +from modules import UNet +from torch.utils.data import DataLoader +from torch.utils.data.dataset import random_split +from utils import dice_loss, dice_coefficient +import torch +import torch.optim as optim +import torch.nn.functional as F + +# Set to True if you encounter dataset inconsistencies +check = False + +# Hyper-parameters +num_epochs = 25 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Method used during validation and testing to evaulate the model +def evaluate_model(model, data_loader, device): + # set the model to evaluation mode + model.eval() + + # List to store individual dice scores for each sample + all_dice_scores = [] + + # No gradient computation during evaluation + with torch.no_grad(): + for images, masks in data_loader: + images, masks = images.to(device), masks.to(device) + + # Compute predictions + outputs = model(images) + + # Convert predictions to binary using threshold + outputs = (outputs > 0.5).float() + + # Compute and store the dice coefficients + batch_dice_scores = dice_coefficient(outputs, masks) + all_dice_scores.extend(batch_dice_scores.cpu().numpy()) + + avg_dice_score = np.mean(all_dice_scores) + min_dice_score = np.min(all_dice_scores) + + return avg_dice_score, min_dice_score + + +# Check if the dataset is consistent +if check: + check_consistency() + + +# Loading up the dataset and applying custom augmentations +dataset = ISICDataset(process_and_augment) + +total_size = len(dataset) + +# Splitting into training, validation and testing sets +train_size = int(total_size * 0.7) +val_size = int(total_size * 0.2) +test_size = total_size - train_size - val_size + +train_dataset, val_dataset, test_dataset = random_split( + dataset, [train_size, val_size, test_size] +) + +# Data loaders for training, validation and testing +train_loader = DataLoader(train_dataset, 32, True) +validation_loader = DataLoader(val_dataset, 100, False) +test_loader = DataLoader(test_dataset, 100, False) + +# Creating an instance of the UNet to be trained +model = UNet(in_channels=6, num_classes=1) +model = model.to(device) + +# Setup the optimizer +optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5) +scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.985) + +print("Training loop:") + +# Variables used for training and validation: +running_loss = 0.0 +no_improvement = 0 +print_every = 20 # Print training feedback every 20 batches +best_val_similarity = 0.0 +best_val_minimum = 0.0 +patience = 8 # Number of epochs to wait before stopping + +for epoch in range(num_epochs): + # Set the model to training mode + model.train() + + for i, (images, masks) in enumerate(train_loader, 1): + # Move the data onto the device + images, masks = images.to(device), masks.to(device) + + # Zero the parameter gradients + optimizer.zero_grad() + + # Forward pass + outputs = model(images) + + loss = dice_loss(outputs, masks) + + # Backward pass + optimization + loss.backward() + optimizer.step() + + # Keep track of the running loss for testing feedback + running_loss += loss.item() + if i % print_every == 0: # Print every `print_every` batches + print( + f"Epoch {epoch + 1}, Batch {i}: Loss = {running_loss / print_every:.4f}" + ) + running_loss = 0.0 + + # Evaluate the model using the validation set + dice_similarity, dice_minimum = evaluate_model(model, 
validation_loader, device) + val_loss = 1 - dice_similarity + # Print out the validation metrics + print( + f"Validation metrics during epoch {epoch + 1}, loss = {val_loss:.4f}, similarity = {dice_similarity:.4f}, minimum = {dice_minimum:.4f}" + ) + + # Model checkpointing + if dice_similarity > best_val_similarity and dice_minimum >= best_val_minimum: + best_val_loss = val_loss + torch.save(model.state_dict(), "best_model.pth") + no_improvement = 0 + else: + no_improvement += 1 + + # Early stoppage if the model hasn't improved in `patience` epochs + if no_improvement > patience: + break + + scheduler.step() # Adjust learning rate + + +print("training complete") + +# Save the model +torch.save( + model.state_dict(), + "/home/Student/s4745275/PatternAnalysis-2023/recognition/Problem_47452752/best_model.pth", +) + +avg_dice_score, min_dice_score = evaluate_model(model, test_loader, device) + +print(f"Average Dice Coefficient on Test Set: {avg_dice_score:.4f}") +print(f"Minimum Dice Coefficient on Test Set: {min_dice_score:.4f}") diff --git a/recognition/UNet_Segmentation_s4745275/utils.py b/recognition/UNet_Segmentation_s4745275/utils.py new file mode 100644 index 000000000..46c4beb47 --- /dev/null +++ b/recognition/UNet_Segmentation_s4745275/utils.py @@ -0,0 +1,111 @@ +"""Contains any helper methods or classes""" +import torch +from random import choice +import torchvision.transforms.functional as TF +import torch.nn as nn + + +class RandomRotate90: + """Randomly rotates the image by 90, 180, or 270 degrees.""" + + def __init__(self, p=1.0): + self.p = p + + def __call__(self, sample): + image, mask = sample["image"], sample["mask"] + if torch.rand(1) < self.p: + degrees = [90, 180, 270] + angle = choice(degrees) + image = TF.rotate(image, angle) + mask = TF.rotate(mask, angle) + return {"image": image, "mask": mask} + + +class RandomCenterCrop: + """Randomly crops the center of the image by 80% or 70%.""" + + def __init__(self, scales=[0.8, 0.7], p=1.0): + self.scales = scales + self.p = p + + def __call__(self, sample): + image, mask = sample["image"], sample["mask"] + if torch.rand(1) < self.p: + scale = choice(self.scales) + image = TF.center_crop( + image, (int(image.height * scale), int(image.width * scale)) + ) + mask = TF.center_crop( + mask, (int(mask.height * scale), int(mask.width * scale)) + ) + return {"image": image, "mask": mask} + + +class DictTransform: + def __init__(self, transform, transform_mask=True): + self.transform = transform + self.transform_mask = transform_mask + + def __call__(self, sample): + image, mask = sample["image"], sample["mask"] + if self.transform_mask: + return { + "image": self.transform(image), + "mask": self.transform(mask), + } + return { + "image": self.transform(image), + "mask": mask, + } + + +# max_dice_loss = max(dice_losses) # Penalize the worst performance +# dice_loss = sum(dice_losses) / len(dice_losses) +def dice_loss(predicted, target, epsilon=1e-7): + # Compute dice coefficient for each image in the batch + dice_scores = dice_coefficient(predicted, target) + # Compute dice loss for each image in the batch + dice_losses = 1.0 - dice_scores + # Penalize any images with dice score less than 0.8 + penalized_losses = torch.where(dice_scores < 0.8, dice_losses * 2, dice_losses) + # Return the average loss + average_penalized_loss = penalized_losses.mean() + + return average_penalized_loss + + +def dice_coefficient( + predicted: torch.Tensor, target: torch.Tensor, epsilon=1e-7 +) -> torch.Tensor: + """Compute dice coefficient for each image 
in the batch""" + predicted = predicted.contiguous().view(predicted.shape[0], -1) + target = target.contiguous().view(target.shape[0], -1) + + intersection = (predicted * target).sum(dim=1) + return (2.0 * intersection + epsilon) / ( + predicted.sum(dim=1) + target.sum(dim=1) + epsilon + ) + + +def general_dice_loss(predicted, target): + # One-hot encode the target segmentation map + target_one_hot = torch.zeros_like(predicted) + for k in range(target_one_hot.shape[1]): + target_one_hot[:, k] = target == k + + # Compute the Dice loss for each class, then average + intersection = (predicted * target_one_hot).sum(dim=(2, 3)) + union = (predicted + target_one_hot).sum(dim=(2, 3)) + + dice_scores = 2 * intersection / union + loss = 1 - dice_scores.mean() + + return loss + + +# lil testing +# pred = torch.randn(3, 6, 32, 32) +# tar = torch.randn(3, 6, 32, 32) +# x = dice_coefficient(pred, tar) +# y = dice_loss(pred, tar) +#
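+
+# Illustrative sanity check (an addition, not part of the commented-out tests above):
+# exercises dice_coefficient and dice_loss on hand-made binary masks with known overlap.
+# Guarded so it only runs when this file is executed directly, not on import.
+if __name__ == "__main__":
+    target = torch.ones(2, 1, 4, 4)
+    perfect = torch.ones(2, 1, 4, 4)    # prediction identical to the target
+    disjoint = torch.zeros(2, 1, 4, 4)  # prediction with no overlap at all
+
+    print(dice_coefficient(perfect, target))   # ~1.0 for each image in the batch
+    print(dice_coefficient(disjoint, target))  # ~0.0 (epsilon keeps it finite)
+    print(dice_loss(perfect, target))          # ~0.0, no penalty applied
+    print(dice_loss(disjoint, target))         # ~2.0, doubled because the score is below 0.8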