How to modify an embedding directly in TensorFlow distributed training

My model has some learnable embeddings with shape (N, D). I train it with a parameter-server/worker distributed architecture.

In each training step, only some of the embeddings are updated during back-propagation. I want to reset the unused (not updated) embeddings after each back-propagation step. How should I implement this? If I do it in the build_model step, as an operation that involves no gradients, will the change be applied on both the parameter servers and the workers? In theory, I think it should run on the parameter servers after the parameter update, but I have no idea how to implement that.

My training code is as follows. Does anyone know where I should insert the code that modifies the unused embeddings?

import tensorflow as tf

# The flags ps_hosts, worker_hosts, job_name and task_index are assumed to be
# defined elsewhere with tf.app.flags.DEFINE_*.
FLAGS = tf.app.flags.FLAGS

ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
# Start a ps or worker server.
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

if FLAGS.job_name == "ps":
  server.join()
elif FLAGS.job_name == "worker":
  # Client: variables are placed on the ps tasks, ops on this worker.
  with tf.device(tf.train.replica_device_setter(
      worker_device="/job:worker/task:%d" % FLAGS.task_index,
      cluster=cluster)):
    # Build model...
    loss = ...
    train_op = ...

  # Connect to this task's own server rather than a bare device string.
  with tf.train.MonitoredTrainingSession(
      master=server.target,
      is_chief=(FLAGS.task_index == 0),
      checkpoint_dir="/tmp/train_logs") as mon_sess:
    while not mon_sess.should_stop():
      mon_sess.run(train_op)

Here is a PyTorch distributed implementation; it is written in the forward-propagation function. Where should the equivalent code go in TensorFlow?

import torch
import torch.distributed as dist

# Fragment from forward(); vectors, n_vectors and n_embed come from the enclosing scope.
# Reset unused embeddings here.
if self.restart_unused_codes:
    # Generate new embeddings for the unused entries.
    if n_vectors < n_embed:
        vectors = self._tile_with_noise(vectors, n_embed)
    n_vectors = vectors.shape[0]
    _vectors_random = vectors[torch.randperm(n_vectors, device=vectors.device)][:n_embed]

    # Broadcast the new embeddings from rank 0 to every node in the distributed system.
    if dist.is_initialized():
        dist.broadcast(_vectors_random, 0)

    # Assign the new embeddings to the unused entries.
    usage = (self.cluster_size_ema.view(-1, 1) >= 1).float()
    self.embed_ema.mul_(usage).add_(_vectors_random * (1 - usage))
    self.cluster_size_ema.mul_(usage.view(-1))
    self.cluster_size_ema.add_(torch.ones_like(self.cluster_size_ema) * (1 - usage).view(-1))
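
To make clear what I am trying to reproduce, here is a minimal, framework-neutral NumPy sketch of the reset step in the PyTorch snippet. It is only an illustration of the arithmetic, not a distributed implementation. The function name `reset_unused`, the `noise_std` parameter, and the repeat-and-jitter stand-in for `_tile_with_noise` (whose body is not shown above) are my own assumptions.

```python
import numpy as np


def reset_unused(embed, cluster_size, candidates, rng, noise_std=0.01):
    """Replace rows of `embed` whose usage count is below 1.

    embed:        (n_embed, dim) current embeddings
    cluster_size: (n_embed,) running usage count per embedding
    candidates:   (n_vectors, dim) vectors to draw replacements from
    """
    n_embed = embed.shape[0]

    if candidates.shape[0] < n_embed:
        # Assumed stand-in for _tile_with_noise: repeat the candidates until
        # there are at least n_embed rows, then add small Gaussian noise.
        reps = -(-n_embed // candidates.shape[0])  # ceiling division
        candidates = np.tile(candidates, (reps, 1))
        candidates = candidates + rng.normal(0.0, noise_std, size=candidates.shape)

    # Pick n_embed candidate rows in random order (mirrors torch.randperm).
    perm = rng.permutation(candidates.shape[0])[:n_embed]
    replacements = candidates[perm]

    # usage is 1.0 for embeddings that were used and 0.0 for unused ones.
    usage = (cluster_size >= 1).astype(embed.dtype)[:, None]

    # Keep used rows, overwrite unused rows, and restart unused counts at 1.
    new_embed = embed * usage + replacements * (1 - usage)
    new_cluster_size = cluster_size * usage[:, 0] + (1 - usage[:, 0])
    return new_embed, new_cluster_size


rng = np.random.default_rng(0)
embed = np.arange(12, dtype=np.float64).reshape(4, 3)
cluster_size = np.array([2.0, 0.0, 5.0, 0.3])  # rows 1 and 3 count as unused
candidates = rng.normal(size=(2, 3))           # fewer candidates than embeddings

new_embed, new_cluster_size = reset_unused(embed, cluster_size, candidates, rng)
```

In this toy run, rows 0 and 2 are left untouched and rows 1 and 3 are overwritten. In the distributed setting, the open question is still where this assignment should run so that every worker sees the same replaced rows.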