Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
e62675a
Add files via upload
ritikarana9999 Oct 20, 2023
7ffe75c
Created using Colaboratory
ritikarana9999 Oct 20, 2023
7715b99
Add files via upload
ritikarana9999 Oct 20, 2023
3392663
Add files via upload
ritikarana9999 Oct 20, 2023
2181156
Delete resources/README.md
ritikarana9999 Oct 20, 2023
e450da3
Delete README.md
ritikarana9999 Oct 20, 2023
146b5fd
Add files via upload
ritikarana9999 Oct 20, 2023
2a0f88a
Update README.md
ritikarana9999 Oct 20, 2023
cdd15fe
Add files via upload
ritikarana9999 Oct 20, 2023
d17b66e
Update README.md
ritikarana9999 Oct 20, 2023
113a76e
Update README.md
ritikarana9999 Oct 20, 2023
49fb582
Update README.md
ritikarana9999 Oct 20, 2023
f039683
Update README.md
ritikarana9999 Oct 20, 2023
7f01acc
Add files via upload
ritikarana9999 Oct 20, 2023
57098ec
Update README.md
ritikarana9999 Oct 20, 2023
f6ea203
Created using Colaboratory
ritikarana9999 Oct 20, 2023
6695000
Created using Colaboratory
ritikarana9999 Oct 21, 2023
4ae370f
Update README.md
ritikarana9999 Oct 21, 2023
b0d5410
Update README.md
ritikarana9999 Oct 21, 2023
8ae341d
Update README.md
ritikarana9999 Oct 21, 2023
4294127
Add files via upload
ritikarana9999 Oct 21, 2023
6917cbb
Update README.md
ritikarana9999 Oct 21, 2023
6d468ad
Created dataset.py
ritikarana9999 Nov 27, 2023
837c036
Creating parameters
ritikarana9999 Nov 28, 2023
026803b
Rename parameters to parameters.py
ritikarana9999 Nov 28, 2023
8db429b
Create modules.py
ritikarana9999 Nov 28, 2023
46645aa
Create predict.py
ritikarana9999 Nov 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,761 changes: 2,761 additions & 0 deletions ADNI_Brain_Visual_Transformer_47306725.ipynb

Large diffs are not rendered by default.

286 changes: 275 additions & 11 deletions README.md

Large diffs are not rendered by default.

38 changes: 38 additions & 0 deletions dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
def loadData():
"""
Loading the dataset.
"""
trainData = tf.keras.preprocessing.image_dataset_from_directory(
DATA_LOAD_DEST + "/train", labels='inferred', label_mode='binary',
image_size=[IMG_SIZE, IMG_SIZE], shuffle=True,
batch_size=BATCH_SIZE, seed=8, class_names=['AD', 'NC']
)

testData = tf.keras.preprocessing.image_dataset_from_directory(
DATA_LOAD_DEST + "/test", labels='inferred', label_mode='binary',
image_size=[IMG_SIZE, IMG_SIZE], shuffle=True,
batch_size=BATCH_SIZE, seed=8, class_names=['AD', 'NC']
)

# Augmenting data
normalize = tf.keras.layers.Normalization()
flip = tf.keras.layers.RandomFlip(mode='horizontal', seed=8)
rotate = tf.keras.layers.RandomRotation(factor=0.02, seed=8)
zoom = tf.keras.layers.RandomZoom(height_factor=0.1, width_factor=0.1, seed=8)

trainData = trainData.map(
lambda x, y: (rotate(flip(zoom(normalize(x)))), y)
)

testData = testData.map(
lambda x, y: (rotate(flip(zoom(normalize(x)))), y)
)

# Taking half of the 9000 images from the test set as validation data
validationData = testData.take(len(list(testData))//2)

# Using remaining images as test set
testData = testData.skip(len(list(testData))//2)

return trainData, validationData, testData

Binary file added google colab commit logs/Screenshot (136).png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added google colab commit logs/Screenshot (137).png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added google colab commit logs/Screenshot (138).png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
216 changes: 216 additions & 0 deletions modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
class PatchLayer(Layer):
"""
Layering and transforming images into patches.
"""
def __init__(self, img_size, patch_size, num_patches, projection_dim, **kwargs):
super(PatchLayer, self).__init__(**kwargs)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.projection_dim = projection_dim
self.half_patch = patch_size // 2
self.flatten_patches = layers.Reshape((num_patches, -1))
self.projection = layers.Dense(units=projection_dim)
self.layer_norm = layers.LayerNormalization(epsilon=1e-6)

def shiftImg(self, images, mode):
# Building diagonally-shifted images
if mode == 'left-up':
cropheight = self.half_patch
cropwidth = self.half_patch
shiftheight = 0
shiftwidth = 0
elif mode == 'left-down':
cropheight = 0
cropwidth = self.half_patch
shiftheight = self.half_patch
shiftwidth = 0
elif mode == 'right-up':
cropheight = self.half_patch
cropwidth = 0
shiftheight = 0
shiftwidth = self.half_patch
else:
cropheight = 0
cropwidth = 0
shiftheight = self.half_patch
shiftwidth = self.half_patch

crop = tf.image.crop_to_bounding_box(
images,
offset_height=cropheight,
offset_width=cropwidth,
target_height=self.img_size - self.half_patch,
target_width=self.img_size - self.half_patch
)

shiftPad = tf.image.pad_to_bounding_box(
crop,
offset_height=shiftheight,
offset_width=shiftwidth,
target_height=self.img_size,
target_width=self.img_size
)
return shiftPad

def call(self, images):
images = tf.concat(
[
images,
self.shiftImg(images, mode='left-up'),
self.shiftImg(images, mode='left-down'),
self.shiftImg(images, mode='right-up'),
self.shiftImg(images, mode='right-down'),
],
axis=-1
)
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size, self.patch_size, 1],
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
padding='VALID'
)
flat_patches = self.flatten_patches(patches)
tokens = self.layer_norm(flat_patches)
tokens = self.projection(tokens)

return (tokens, patches)

def getConfig_(self):
config_ = super(PatchLayer, self).getConfig_()
config_.update(
{
'img_size': self.img_size,
'patch_size': self.patch_size,
'num_patches': self.num_patches,
'projection_dim': self.projection_dim
}
)
return config_

class Embed_Patch(Layer):
"""
Layering for projecting patches into a vector.
"""
def __init__(self, num_patches, projection_dim, **kwargs):
super(Embed_Patch, self).__init__(**kwargs)
self.num_patches = num_patches
self.projection_dim = projection_dim
self.position_embedding = layers.Embedding(
input_dim=self.num_patches, output_dim=projection_dim
)

def call(self, patches):
positions = tf.range(0, self.num_patches, delta=1)
return patches + self.position_embedding(positions)

def getConfig_(self):
config_ = super(Embed_Patch, self).getConfig_()
config_.update(
{
'num_patches': self.num_patches,
'projection_dim': self.projection_dim
}
)
return config_

class Multi_Head_AttentionLSA(layers.MultiHeadAttention):
"""
Multi Head Attention layer for the transformer-encoder block, but with the
addition of using Local-Self-Attention to improve feature learning.

"""
def __init__(self, **kwargs):
super(Multi_Head_AttentionLSA, self).__init__(**kwargs)
self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)

def computeAttention(self, query, key, value, attention_mask=None,
training=None):
query = tf.multiply(query, 1.0/self.tau)
attention_scores = tf.einsum(self._dot_product_equation, key, query)
attention_mask = tf.convert_to_tensor(attention_mask)
attention_scores = self._masked_softmax(attention_scores, attention_mask)
attention_scores_dropout = self._dropout_layer(
attention_scores, training=training
)
attention_output = tf.einsum(
self._combine_equation, attention_scores_dropout, value
)
return attention_output, attention_scores

def getConfig_(self):
config_ = super(Multi_Head_AttentionLSA, self).getConfig_()
return config_


def buildVisionTransformer(input_shape, img_size, patch_size, num_patches,
attention_heads, projection_dim, hidden_units, dropout_rate,
transf_layers, mlp_head_units):
"""
Building the vision transformer.
"""
# Input layer
inputs = layers.Input(shape=input_shape)

# Convert image data into patches
(tokens, _) = PatchLayer(
img_size,
patch_size,
num_patches,
projection_dim
)(inputs)

# Encode patches
encodedPatches = Embed_Patch(num_patches, projection_dim)(tokens)

# Create transformer layers
for _ in range(transf_layers):
# First layer normalisation
layerNorm1 = layers.LayerNormalization(
epsilon=1e-6
)(encodedPatches)

# Build diagoanl attention mask
diagAttnMask = 1 - tf.eye(num_patches)
diagAttnMask = tf.cast([diagAttnMask], dtype=tf.int8)

# Multi-head attention layer
attention_output = Multi_Head_AttentionLSA(
num_heads=attention_heads, key_dim=projection_dim,
dropout=dropout_rate
)(layerNorm1, layerNorm1, attention_mask=diagAttnMask)

# First skip connection
skip1 = layers.Add()([attention_output, encodedPatches])

# Second layer normalisation
layerNorm2 = layers.LayerNormalization(epsilon=1e-6)(skip1)

# Multi-Layer Perceptron
mlpLayer = layerNorm2
for units in hidden_units:
mlpLayer = layers.Dense(units, activation=tf.nn.gelu)(mlpLayer)
mlpLayer = layers.Dropout(dropout_rate)(mlpLayer, training=False)

# Second skip connection
encodedPatches = layers.Add()([mlpLayer, skip1])

# Create a [batch_size, projection_dim] tensor
representtn = layers.LayerNormalization(epsilon=1e-6)(encodedPatches)
representtn = layers.Flatten()(representtn)
representtn = layers.Dropout(dropout_rate)(representtn, training=False)

# MLP layer for learning features
features = representtn
for units in mlp_head_units:
features = layers.Dense(units, activation=tf.nn.gelu)(features)
features = layers.Dropout(dropout_rate)(features, training=False)

# Classify outputs
logits = layers.Dense(1)(features)

# Create Keras model
model = tf.keras.Model(inputs=inputs, outputs=logits)

return model
20 changes: 20 additions & 0 deletions parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Hyperparameters
IMG_SIZE = 128
PATCH_SIZE = 8
BATCH_SIZE = 32
EPOCHS = 10
WEIGHT_DECAY = 0.0001
PROJECTION_DIM = 512 # MLP-blocks depth
LEARN_RATE = 0.0005
TRANSF_LAYERS = 5 # No. of transformer-encoder-blocks
DROPOUT_RATE = 0.2
ATTENTION_HEADS = 5
MLP_HEAD_UNITS = [256, 128]
DATA_LOAD_DEST = "/content"
MODEL_SAVE_DEST = "/content/vision_transformer"


#Calculating automat
INPUT_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
HIDDEN_UNITS = [PROJECTION_DIM * 2, PROJECTION_DIM]
NUM_PATCHES = int((IMG_SIZE/PATCH_SIZE) ** 2)
47 changes: 47 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
def predict(load_path, testData):
optimizer = tfa.optimizers.AdamW(
learning_rate=LEARN_RATE,
weight_decay=WEIGHT_DECAY
)

model = tf.keras.models.load_model(
load_path,
custom_objects={
'PatchLayer': PatchLayer,
'Embed_Patch': Embed_Patch,
'MultiheadattentionLSA': Multi_Head_AttentionLSA,
'AdamW': optimizer
}
)

model.evaluate(testData)

# Plot confusion matrix
y_true = []
y_pred = []

for image_batch, label_batch in testData:
y_true.append(label_batch)
y_pred.append((model.predict(image_batch, verbose=0) > 0.5).astype('int32'))

labels_true = tf.concat([tf.cast(item[0], tf.int32) for item in y_true], axis=0)
labels_pred = tf.concat([item[0] for item in y_pred], axis=0)

matrix = tf.math.confusion_matrix(labels_true, labels_pred, 2).numpy()

fig, ax = plt.subplots(figsize=(8,8))
ax.matshow(matrix, cmap=plt.cm.Blues, alpha=0.3)
for i in range(matrix.shape[0]):
for j in range(matrix.shape[1]):
ax.text(x=j, y=i, s=matrix[i, j], va='center', ha='center', size='xx-large')

plt.xlabel('Predictions', fontsize=18)
plt.ylabel('Actual Label', fontsize=18)
plt.suptitle('Confusion Matrix', fontsize=18)
plt.savefig('confusion_matrix')
plt.clf()


if __name__ == '__main__':
train, val, test = loadData()
predict(MODEL_SAVE_DEST, test)
Binary file added resources/accuracy (1).png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/accuracy (2).png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/accuracy (3).png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/accuracy.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/brain1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/brainl.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/brainld.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/brainr.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/brainrd.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/confusion_matrix (1).png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/confusion_matrix (2).png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/confusion_matrix.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/localitattention.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/losses (1).png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/losses (2).png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/losses (3).png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/losses.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/transformer_block.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/vt.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.