Update all worker replicas from one worker using MultiWorkerMirroredStrategy

I’m trying to implement a distributed reinforcement learning system using tf.distribute.MultiWorkerMirroredStrategy, but I’m not sure exactly how the strategy is supposed to work. My requirement is to update the replicas on all workers from a single training process. The simplified script below launches a separate worker process. In step 1, the replicas in the main process and in the worker process are in sync.

In step 2, I update the model weights in the main process only and hope that the weights of the model in the worker process get updated too. Note that the gradients computed in the main process are not distributed; I expected the optimizer to apply them to all replicas automatically.
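For comparison, here is a plain-Python sketch of synchronous data-parallel training, the model that mirrored strategies are generally described as following. It is an illustration under simplifying assumptions (scalar weights, plain SGD, an in-process average standing in for the cross-worker all-reduce), not TensorFlow code, and the names `all_reduce_mean` and `sync_step` are hypothetical. What it shows: the replicas stay identical because each one applies the same reduced gradient to its own copy, not because one replica pushes its weights to the others.

```python
# Plain-Python illustration of synchronous data-parallel training (no TensorFlow).
# Assumptions: scalar weights, plain SGD, and an in-process average standing in
# for the cross-worker all-reduce. Function names are hypothetical.

def all_reduce_mean(values):
    """Stand-in for a collective: average one contribution from every replica."""
    return sum(values) / len(values)

def sync_step(replica_weights, local_grads, lr=0.5):
    """Run one synchronous SGD step across all replicas.

    Every replica must contribute a gradient; a replica that does not
    take part leaves the collective incomplete.
    """
    if len(local_grads) != len(replica_weights):
        raise RuntimeError("every replica must take part in the all-reduce")
    avg_grad = all_reduce_mean(local_grads)
    # Each replica updates its OWN copy with the same reduced gradient,
    # which is what keeps the copies identical.
    return [w - lr * avg_grad for w in replica_weights]

weights = [1.0, 1.0]                        # two replicas, identical start
weights = sync_step(weights, [0.5, 0.25])   # each replica's local gradient
print(weights)                              # prints [0.8125, 0.8125]
```

In this picture, updating only one replica corresponds to calling `sync_step` with a single gradient, which the sketch rejects instead of silently letting the copies drift apart.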

import time
import tensorflow as tf
import os
import json
import multiprocessing

tf_config = {
    'cluster': {
        'worker': ['localhost:2222', 'localhost:2223'],
    },
    'task': {
        'type': 'worker',
        'index': 0,
    }
}

def create_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(2, input_shape=(1,))
    ])
    return model

def worker_proc():
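    # Build a fresh 'task' entry instead of mutating the shared nested dict
    # (tf_config.copy() is shallow, so tf_config['task'] would be modified too).
    worker_config = {'cluster': tf_config['cluster'],
                     'task': {'type': 'worker', 'index': 1}}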
    os.environ["TF_CONFIG"] = json.dumps(worker_config)
    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    with strategy.scope():
        model = create_model()
        print("worker replica - step 1: ", model.get_weights())

    time.sleep(10)
    print("worker replica - step 2: ", model.get_weights())


def main():
    os.environ["TF_CONFIG"] = json.dumps(tf_config)
    worker = multiprocessing.Process(target=worker_proc)
    worker.start()

    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    
    with strategy.scope():
        model = create_model()
        print("main replica - step 1: ", model.get_weights())
        optimizer = tf.keras.optimizers.Adam()

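    with tf.GradientTape() as tape:
        # MSE(y_true, y_pred): target first, prediction second
        loss = tf.keras.losses.MSE(tf.constant([[1.]]), model(tf.constant([[1.]])))
    # Compute the gradients after leaving the tape context
    grads = tape.gradient(loss, model.trainable_variables)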

    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    print("main replica - step 2: ", model.get_weights())

    worker.join()


if __name__ == '__main__':
    main()
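A side note on building a per-task TF_CONFIG: `dict.copy()` is shallow, so setting `['task']['index']` on the copy also changes the original's nested `task` dict. In this script that is probably not the cause of the errors below, since `worker_proc` runs in its own process with its own copy of the module state, but it is an easy trap when configs for several tasks are built in one process. A minimal sketch; `base_config` and `make_tf_config` are hypothetical helpers, not part of TensorFlow:

```python
import copy
import json

def base_config():
    """The cluster spec used in the script, task 0 by default."""
    return {
        'cluster': {'worker': ['localhost:2222', 'localhost:2223']},
        'task': {'type': 'worker', 'index': 0},
    }

# dict.copy() is shallow: the nested 'task' dict is shared with the original.
original = base_config()
shallow = original.copy()
shallow['task']['index'] = 1
print(original['task']['index'])  # prints 1: the original was modified too

# copy.deepcopy() also copies the nested dicts, so the original is left alone.
original = base_config()
deep = copy.deepcopy(original)
deep['task']['index'] = 1
print(original['task']['index'])  # prints 0

def make_tf_config(index):
    """Hypothetical helper: the TF_CONFIG JSON string for worker `index`."""
    cfg = base_config()           # fresh dict every call, nothing shared
    cfg['task']['index'] = index
    return json.dumps(cfg)
```

Each process would then set `os.environ["TF_CONFIG"] = make_tf_config(i)` with its own index before creating the strategy.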

I tried creating the optimizer inside the strategy scope, as in the script above, but applying the gradients fails with:

Collective ops is aborted by: cluster check alive failed, /job:worker/replica:0/task:1 is down
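The message suggests the main process is blocked in a collective operation that task 1 never joins: the worker process only sleeps, prints, and exits. Which collective `apply_gradients` triggers here (for example an all-reduce of gradients, or a broadcast of newly created optimizer variables) is an assumption. As a plain-Python analogy, not TensorFlow's implementation, a collective can be modeled as a barrier that every participant must reach before anyone gets the result; `run_collective` is a hypothetical name:

```python
import threading

def run_collective(n_workers, participating, timeout=1.0):
    """Simulate an all-reduce (sum) among n_workers threads.

    Only the workers listed in `participating` call into the collective;
    the others return early, like a worker process that never runs the step.
    Returns per-worker results; 'aborted' marks a collective that timed out.
    """
    barrier = threading.Barrier(n_workers, timeout=timeout)
    contributions = [0.0] * n_workers
    results = [None] * n_workers

    def worker(rank):
        if rank not in participating:
            return  # this worker never joins the collective
        contributions[rank] = float(rank + 1)
        try:
            barrier.wait()  # every worker must arrive before reducing
        except threading.BrokenBarrierError:
            results[rank] = 'aborted'  # timed out waiting for a peer
            return
        results[rank] = sum(contributions)

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(n_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

print(run_collective(2, {0, 1}))  # prints [3.0, 3.0]
print(run_collective(2, {0}))     # prints ['aborted', None] after the timeout
```

In the analogy, the main process is worker 0 and the worker process is worker 1, which never calls the collective.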

I also tried creating the optimizer outside the strategy scope, planning to call strategy.extended.update() to propagate the weights after the optimizer updates them locally. But then I get this error:

ValueError: Trying to create optimizer slot variable under the scope for tf.distribute.Strategy (<tensorflow.python.distribute.distribute_lib._DefaultDistributionStrategy object at 0x7f5bd555b940>), which is different from the scope used for the original variable

What am I missing here?