KNN accuracy implementation

I have implemented knn accuracy in tensorflow following this code which is written in pytorch. The aim is to determine how well my model is performing on image embeddings. My implementation is done using a callback, however I get the same score of 1.17% on every epoch. Could someone kindly point me in the right direction or what I am doing wrong. I have included the code for your reference. Thank you in advance.

class KNNMonitor(tf.keras.callbacks.Callback):
    def __init__(self,index,query,k=200,t=0.1):
        super(KNNMonitor,self).__init__()
        self.index = index
        self.query = query
        self.k = k
        self.t = t
        self.top_1 = 0.0
        self.total_num = 0
    def on_epoch_end(self,epoch,logs=None):
#         self.top_1 = 0.0
#         self.total_num = 0
        self.feature_bank = []
        self.index_labels = []
        for x,y in self.index:
            x = self.model.encoder(x)
            self.feature_bank.append(x)
            self.index_labels.append(y)
        self.feature_bank = tf.concat(self.feature_bank,axis=0) # output[N,2048]
        self.index_labels = tf.concat(self.index_labels,axis=0) #output[N]
        for qx,qy in self.query:
            qx = self.model.encoder(qx) # output[B,2048]
            self.total_num += qx.shape[0]
            cos_distance = tf.linalg.matmul(qx,self.feature_bank,transpose_b=True) # output[B,N]
            k_values,k_indices = tf.raw_ops.TopKV2(input=cos_distance,k=self.k) # output (B,K)
            sim_labels = tf.gather(self.index_labels,k_indices,axis=-1) # output(B,K)
            sim_weight = tf.math.exp(k_values/self.t) # output[B,K]
            
            # one-hot-encode topk labels
            one_hot_label = tf.zeros([qx.shape[0]*self.k,C.CLASSES]) # output[B*K,C]
            sl = tf.reshape(sim_labels,[1,-1]) # output[1,B*K]
            r = tf.cast(
                tf.reshape(tf.range(0,qx.shape[0]*self.k),[1,-1]),
                dtype=tf.int64
            ) # output[1,B*K]
            indices = tf.transpose(tf.concat([r,sl],axis=0)) # output[B*K,2]
            one_hot_label = tf.tensor_scatter_nd_update(
                one_hot_label,
                indices=indices,
                updates= tf.ones(qx.shape[0]*self.k)
            ) # output[B*K,C]
            
            pred_scores = tf.math.reduce_sum(
                tf.reshape(one_hot_label,[qx.shape[0],-1,C.CLASSES]) *
                tf.expand_dims(sim_weight,axis=-1),
                axis=1
            ) # output[B,C]
            pred_scores_indices = tf.argsort(pred_scores,axis=-1,direction='DESCENDING') # output[B,C]
            total = (pred_scores_indices[:, 0] == tf.cast(qy,dtype=tf.int32))
            self.top_1 +=  tf.math.reduce_sum(tf.cast(total,dtype=tf.float32)).numpy()
        print(f"knn accuracy:{self.top_1 / self.total_num *100:.2f}%")