Question: Multi-worker training with Keras

I am trying to practice the "Multi-worker training with Keras" example in Colab:

import json
import os
import sys
os.environ[“CUDA_VISIBLE_DEVICES”] = “-1”
os.environ.pop(‘TF_CONFIG’, None) ```
if ‘.’ not in sys.path:
sys.path.insert(0, ‘.’)
!pip install tf-nightly
import tensorflow as tf

import os
import tensorflow as tf
import numpy as np

def mnist_dataset(batch_size):
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train / np.float32(255)
y_train = y_train.astype(np.int64)
train_dataset =
(x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
return train_dataset

def build_and_compile_cnn_model():
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(32, 3, activation=‘relu’),
tf.keras.layers.Dense(128, activation=‘relu’),
return model

import mnist_setup

# Train once on a single worker to check the model and data pipeline.
batch_size = 64
single_worker_dataset = mnist_setup.mnist_dataset(batch_size)
single_worker_model = mnist_setup.build_and_compile_cnn_model()
# The dataset repeats forever, so steps_per_epoch is what ends each epoch.
single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)

# Cluster description for two workers on localhost. 'task' says which entry of
# 'worker' the process receiving this config is (index 0 = first address).
tf_config = {
    'cluster': {
        'worker': ['localhost:12345', 'localhost:23456'],
    },
    'task': {'type': 'worker', 'index': 0},
}


os.environ[‘GREETINGS’] = ‘Hello TensorFlow!’

!echo ${GREETINGS}

strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope(): # Model building/compiling need to be within strategy.scope(). multi_worker_model = mnist_setup.build_and_compile_cnn_model()


import os
import json

import tensorflow as tf
import mnist_setup

per_worker_batch_size = 64
tf_config = json.loads(os.environ[‘TF_CONFIG’])
num_workers = len(tf_config[‘cluster’][‘worker’])

strategy = tf.distribute.MultiWorkerMirroredStrategy()

global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_setup.mnist_dataset(global_batch_size)

with strategy.scope():
multi_worker_model = mnist_setup.build_and_compile_cnn_model(), epochs=3, steps_per_epoch=70)

!ls *.py

os.environ[‘TF_CONFIG’] = json.dumps(tf_config)


!python &> job_0.log

The last step (`!python &> job_0.log`) keeps running and never stops.
Please tell me how to solve it. Thank you very much.

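Hi @hsiaokai0309, if you are training on a single worker, you only need to run the code up to (and including) the single-worker training step above. If you are running on multiple workers, skip that part and run the remaining code instead. Note that `!python &> job_0.log` does not pass any script to `python`, so the interpreter waits for a program on standard input and the cell never finishes. Save the multi-worker code to a file first (for example with `%%writefile main.py`) and run `python main.py`. Even then, each worker waits until all workers listed in `TF_CONFIG` have joined. Start worker 0 in the background (with `%%bash --bg` or a trailing `&`), then start worker 1 with `tf_config['task']['index'] = 1`. Thank you.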