Help - Triplet Loss not working properly

I am working on a Siamese network for image similarity calculation. I took inspiration from this example in the Keras documentation, which worked. However, when I adapted it to my own problem, which requires training ResNet from scratch, the loss dropped to nearly zero almost instantly when training without a margin, and it stayed at the value of the margin for the entire training process when training with a margin.

I tried several things: removing the attention layer, adding average pooling, changing the learning rate, and changing the margin.

I also did some additional testing: I let the model train for some time and then inspected the embeddings it created, which had a variance of 0.0 when tested like this:

# Variance of the anchor embeddings along axis 0, reported per dimension
# and then averaged into a single number.
variance = np.asarray(anchor_em).var(axis=0)
print("Variance of each dimension:", variance)
print("Mean variance:", variance.mean())

Variance of each dimension: [0. 0. 0. ... 0. 0. 0.]
Mean variance: 0.0

I also calculated the cosine similarity:

# NOTE(review): this is the Keras *loss* function; per its documentation it
# returns the negated cosine similarity, so values close to -1 would mean the
# two embeddings point in almost the same direction - confirm when reading
# the printed numbers.
cosine_loss = tf.keras.losses.cosine_similarity

cos_sim_anchor_positive = cosine_loss(anchor_em, positive_em)
cos_sim_anchor_negative = cosine_loss(anchor_em, negative_em)
cos_sim_positive_negative = cosine_loss(positive_em, negative_em)

Cosine Similarity (Anchor, Positive): [-0.97650385]
Cosine Similarity (Anchor, Negative): [-0.9831804]
Cosine Similarity (Positive, Negative): [-0.99779606]

The images in my dataset all have been processed like this:

# NOTE(review): the right-hand side of this assignment was missing in the
# original snippet. Reading the raw file bytes is assumed here, and `filename`
# is assumed to be the path of the JPEG being loaded - confirm against the
# real dataset pipeline.
image_string = tf.io.read_file(filename)
# Decode the JPEG bytes into a 3-channel image tensor.
image = tf.image.decode_jpeg(image_string, channels=3)
# Convert to float32. NOTE(review): for integer inputs this conversion is
# documented to rescale values into [0, 1], which matters for any later
# preprocessing that expects 0-255 pixels - confirm.
image = tf.image.convert_image_dtype(image, tf.float32)
# Resize to the spatial size expected by the network.
image = tf.image.resize(image, target_shape)

This is the full model code:

import keras
import tensorflow as tf
from keras import applications
from keras import layers
from keras import losses
from keras import optimizers
from keras import metrics
from keras import Model
from keras.applications import resnet
from keras.layers import Layer, Input
from tensorflow.keras.layers import Layer, Conv2D, MaxPooling2D, GlobalAveragePooling2D, Dense, Multiply, Reshape

target_shape = (224,224)

# Attention Layer

class LocalAttentionLayer(Layer):
    """Multiply every input pixel by a learned, sigmoid-bounded weight.

    A small convolutional branch (conv -> max-pool -> global average pool)
    summarises the input into one vector. A sigmoid-activated dense layer
    expands that vector into one weight per pixel, and the input is multiplied
    by the resulting single-channel map.

    Args:
        spatial_shape: ``(height, width)`` of the images fed to the layer.
            The dense layer emits ``height * width`` weights, so inputs must
            have exactly this spatial size. Defaults to ``(224, 224)``.
        **kwargs: Forwarded to the base ``Layer`` constructor (e.g. ``name``).
    """

    def __init__(self, spatial_shape=(224, 224), **kwargs):
        super().__init__(**kwargs)
        height, width = spatial_shape
        self.spatial_shape = (height, width)
        # Simplified mini-CNN layers for attention map generation.
        self.attention_conv = Conv2D(filters=32, kernel_size=3, padding="same", activation="relu")
        self.attention_pool = MaxPooling2D(pool_size=4)
        self.attention_global_pool = GlobalAveragePooling2D()
        # One sigmoid weight per pixel of the input image.
        self.attention_dense = Dense(units=height * width, activation="sigmoid")
        # Created once here instead of being re-instantiated on every call().
        self.attention_reshape = Reshape((height, width, 1))

    def call(self, inputs):
        """Return ``inputs`` scaled element-wise by the attention map."""
        # Generate the flat attention vector.
        x = self.attention_conv(inputs)
        x = self.attention_pool(x)
        x = self.attention_global_pool(x)
        attention_weights = self.attention_dense(x)
        # Reshape the weights to match the input's spatial dimensions.
        attention_weights = self.attention_reshape(attention_weights)
        # The trailing singleton channel broadcasts over all input channels,
        # which is equivalent to tiling the map once per channel.
        return inputs * attention_weights

    def get_config(self):
        """Return the layer configuration, including ``spatial_shape``."""
        config = super().get_config()
        config.update({"spatial_shape": self.spatial_shape})
        return config

# Base Model Declaration

# ResNet50 backbone with randomly initialised weights (weights=None) and
# without the classification head (include_top=False).
base_cnn = resnet.ResNet50(weights=None, input_shape=target_shape + (3,), include_top=False)

# Named `embedding_input` rather than `input` so the builtin is not shadowed.
embedding_input = Input(target_shape + (3,))
attended = LocalAttentionLayer()(embedding_input)
# NOTE(review): with include_top=False and no `pooling` argument the backbone
# is documented to return a spatial feature map rather than a flat vector;
# confirm that this is the intended shape for the embedding distances.
embedding_output = base_cnn(attended)

embedding = Model(inputs=embedding_input, outputs=embedding_output, name="Embedding")

# Siamese Model

anchor_input = layers.Input(name="anchor", shape=target_shape + (3,))
positive_input = layers.Input(name="positive", shape=target_shape + (3,))
negative_input = layers.Input(name="negative", shape=target_shape + (3,))

# `resnet.preprocess_input` is documented to expect pixel values in [0, 255].
# NOTE(review): the dataset images are assumed to be float32 in [0, 1]
# (produced by `tf.image.convert_image_dtype`). Without rescaling, subtracting
# the per-channel means would leave every image almost constant. Remove this
# layer if the pipeline already yields 0-255 values.
to_pixel_range = layers.Rescaling(scale=255.0)

# The same `embedding` model (shared weights) encodes all three inputs.
anchor_em = embedding(resnet.preprocess_input(to_pixel_range(anchor_input)))
positive_em = embedding(resnet.preprocess_input(to_pixel_range(positive_input)))
negative_em = embedding(resnet.preprocess_input(to_pixel_range(negative_input)))

siamese_network = Model(
    inputs=[anchor_input, positive_input, negative_input],
    outputs=[anchor_em, positive_em, negative_em],
)

# Training Loop

class SiameseModel(Model):
    """Wrap a triplet network with custom training and evaluation steps.

    The wrapped network maps an (anchor, positive, negative) input triple to
    three embeddings. The loss is

        mean(max(d(anchor, positive) - d(anchor, negative) + margin, 0))

    where ``d`` is the squared Euclidean distance summed over the last axis.

    Args:
        siamese_network: Model returning the three embeddings.
        margin: Minimum gap enforced between the positive and the negative
            distance. Defaults to ``0.2``.
    """

    def __init__(self, siamese_network, margin=0.2):
        # The base Model must be initialised before attributes are assigned.
        super().__init__()
        self.siamese_network = siamese_network
        self.margin = margin
        self.loss_tracker = metrics.Mean(name="loss")

    def call(self, inputs):
        """Return the embeddings produced by the wrapped network."""
        return self.siamese_network(inputs)

    def train_step(self, data):
        """Run one optimisation step and return the running mean loss."""
        with tf.GradientTape() as tape:
            # training=True so layers such as BatchNormalization run in
            # training mode while the backbone is trained from scratch.
            loss = self._compute_loss(data, training=True)
        gradients = tape.gradient(loss, self.siamese_network.trainable_weights)
        # Applying the gradients on the model using the specified optimizer.
        self.optimizer.apply_gradients(
            zip(gradients, self.siamese_network.trainable_weights)
        )
        # Without this update the tracker never sees a loss value.
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        """Evaluate one batch and return the running mean loss."""
        loss = self._compute_loss(data)
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def _compute_loss(self, data, training=False):
        """Compute the mean triplet loss for one batch.

        Args:
            data: Batch passed straight to the wrapped network.
            training: Forwarded to the wrapped network's call.
        """
        anchor, positive, negative = self.siamese_network(data, training=training)
        pos_dist = tf.reduce_sum(tf.square(anchor - positive), axis=-1)
        neg_dist = tf.reduce_sum(tf.square(anchor - negative), axis=-1)
        basic_loss = pos_dist - neg_dist + self.margin
        # Hinge: triplets already separated by at least the margin give zero.
        return tf.reduce_mean(tf.maximum(basic_loss, 0.0))

    @property
    def metrics(self):
        """Metrics to report; listing the tracker here lets Keras reset it."""
        return [self.loss_tracker]

siamese_model = SiameseModel(siamese_network)
siamese_model.compile(optimizer=optimizers.Adam(0.0001))
# NOTE(review): the original line fused the compile and fit calls; the
# training dataset argument was lost. `train_dataset` is assumed to be the
# training counterpart of `val_dataset` - confirm the actual name.
siamese_model.fit(train_dataset, epochs=10, validation_data=val_dataset)

What I have tried has been described above already.
What I expected to happen is for the model to learn to create meaningful embeddings for the images.
Note: The similarity to be computed is not visual similarity but stylistic similarity (artist recognition), so the paired images are alike not in a visual sense but in a stylistic sense.