Unable to use multiple GPUs with tf.keras.layers.Normalization

I’m using TensorFlow 2.7.0 with Python 3.8.10 (Linux environment).

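This error occurs when I add a normalization layer to my model. If I remove the normalization layer, or limit the number of GPUs to 1 or 2, the code runs normally. With 4 GPUs, however, the error below occurs. According to the traceback, the failing operation is the subtraction inside the Normalization layer on `replica_3`, and it fails during evaluation (`test_step` / `test_function`) rather than during a training step.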

Traceback (most recent call last):
  File "flat_resnet50.py", line 263, in <module>
    history = train_model(train_path, validation_path, epochs, steps_per_epoch, resnet_50V2)
  File "flat_resnet50.py", line 82, in train_model
    history = model.fit(x=train_dataset,
  File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/execute.py", line 58, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: 5 root error(s) found.
  (0) INVALID_ARGUMENT:  required broadcastable shapes
	 [[node replica_3/model_2/normalization/sub
 (defined at /usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py:257)
]]
	 [[div_no_nan_1/_181]]
  (1) INVALID_ARGUMENT:  required broadcastable shapes
	 [[node replica_3/model_2/normalization/sub
 (defined at /usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py:257)
]]
	 [[div_no_nan_1/ReadVariableOp_2/_152]]
  (2) INVALID_ARGUMENT:  required broadcastable shapes
	 [[node replica_3/model_2/normalization/sub
 (defined at /usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py:257)
]]
	 [[cond/output/_19/_108]]
  (3) INVALID_ARGUMENT:  required broadcastable shapes
	 [[node replica_3/model_2/normalization/sub
 (defined at /usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py:257)
]]
	 [[GroupCrossDeviceControlEdges_2/NoOp/_225]]
  (4) INVALID_ARGUMENT:  required broadcastable shapes
	 [[node replica_3/model_2/normalization/sub
 (defined at /usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py:257)
]]
0 successful operations.
0 derived errors ignored. [Op:__inference_test_function_119674]

Errors may have originated from an input operation.
Input Source operations connected to node replica_3/model_2/normalization/sub:
In[0] cond/Identity_3 (defined at /usr/local/lib/python3.8/dist-packages/keras/engine/training.py:1355)	
In[1] model_2/normalization/sub/y:

Operation defined at: (most recent call last)
>>>   File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
>>>     self._bootstrap_inner()
>>> 
>>>   File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
>>>     self.run()
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1349, in run_step
>>>     outputs = model.test_step(data)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1303, in test_step
>>>     y_pred = self(x, training=False)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 451, in call
>>>     return self._run_internal_graph(
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 589, in _run_internal_graph
>>>     outputs = node.layer(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py", line 257, in call
>>>     return ((inputs - self.mean) /
>>> 

Input Source operations connected to node replica_3/model_2/normalization/sub:
In[0] cond/Identity_3 (defined at /usr/local/lib/python3.8/dist-packages/keras/engine/training.py:1355)	
In[1] model_2/normalization/sub/y:

Operation defined at: (most recent call last)
>>>   File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
>>>     self._bootstrap_inner()
>>> 
>>>   File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
>>>     self.run()
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1349, in run_step
>>>     outputs = model.test_step(data)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1303, in test_step
>>>     y_pred = self(x, training=False)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 451, in call
>>>     return self._run_internal_graph(
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 589, in _run_internal_graph
>>>     outputs = node.layer(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py", line 257, in call
>>>     return ((inputs - self.mean) /
>>> 

Input Source operations connected to node replica_3/model_2/normalization/sub:
In[0] cond/Identity_3 (defined at /usr/local/lib/python3.8/dist-packages/keras/engine/training.py:1355)	
In[1] model_2/normalization/sub/y:

Operation defined at: (most recent call last)
>>>   File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
>>>     self._bootstrap_inner()
>>> 
>>>   File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
>>>     self.run()
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1349, in run_step
>>>     outputs = model.test_step(data)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1303, in test_step
>>>     y_pred = self(x, training=False)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 451, in call
>>>     return self._run_internal_graph(
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 589, in _run_internal_graph
>>>     outputs = node.layer(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py", line 257, in call
>>>     return ((inputs - self.mean) /
>>> 

Input Source operations connected to node replica_3/model_2/normalization/sub:
In[0] cond/Identity_3 (defined at /usr/local/lib/python3.8/dist-packages/keras/engine/training.py:1355)	
In[1] model_2/normalization/sub/y:

Operation defined at: (most recent call last)
>>>   File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
>>>     self._bootstrap_inner()
>>> 
>>>   File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
>>>     self.run()
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1349, in run_step
>>>     outputs = model.test_step(data)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1303, in test_step
>>>     y_pred = self(x, training=False)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 451, in call
>>>     return self._run_internal_graph(
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 589, in _run_internal_graph
>>>     outputs = node.layer(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py", line 257, in call
>>>     return ((inputs - self.mean) /
>>> 

Input Source operations connected to node replica_3/model_2/normalization/sub:
In[0] cond/Identity_3 (defined at /usr/local/lib/python3.8/dist-packages/keras/engine/training.py:1355)	
In[1] model_2/normalization/sub/y:

Operation defined at: (most recent call last)
>>>   File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
>>>     self._bootstrap_inner()
>>> 
>>>   File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
>>>     self.run()
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1349, in run_step
>>>     outputs = model.test_step(data)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1303, in test_step
>>>     y_pred = self(x, training=False)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 451, in call
>>>     return self._run_internal_graph(
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 589, in _run_internal_graph
>>>     outputs = node.layer(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py", line 257, in call
>>>     return ((inputs - self.mean) /
>>> 

Function call stack:
test_function -> test_function -> test_function -> test_function -> test_function

Error in sys.excepthook:
Traceback (most recent call last):
  File "/usr/lib/python3/dist-packages/apport_python_hook.py", line 72, in apport_excepthook
    from apport.fileutils import likely_packaged, get_recent_crashes
  File "/usr/lib/python3/dist-packages/apport/__init__.py", line 5, in <module>
    from apport.report import Report
  File "/usr/lib/python3/dist-packages/apport/report.py", line 32, in <module>
    import apport.fileutils
  File "/usr/lib/python3/dist-packages/apport/fileutils.py", line 27, in <module>
    from apport.packaging_impl import impl as packaging
  File "/usr/lib/python3/dist-packages/apport/packaging_impl.py", line 23, in <module>
    import apt
  File "/usr/lib/python3/dist-packages/apt/__init__.py", line 35, in <module>
    apt_pkg.init_config()
apt_pkg.Error: E:Syntax error /etc/apt/apt.conf.d/20auto-upgrades:6: Extra junk at end of file

Original exception was:
Traceback (most recent call last):
  File "flat_resnet50.py", line 263, in <module>
    history = train_model(train_path, validation_path, epochs, steps_per_epoch, resnet_50V2)
  File "flat_resnet50.py", line 82, in train_model
    history = model.fit(x=train_dataset,
  File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/execute.py", line 58, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: 5 root error(s) found.
  (0) INVALID_ARGUMENT:  required broadcastable shapes
	 [[node replica_3/model_2/normalization/sub
 (defined at /usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py:257)
]]
	 [[div_no_nan_1/_181]]
  (1) INVALID_ARGUMENT:  required broadcastable shapes
	 [[node replica_3/model_2/normalization/sub
 (defined at /usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py:257)
]]
	 [[div_no_nan_1/ReadVariableOp_2/_152]]
  (2) INVALID_ARGUMENT:  required broadcastable shapes
	 [[node replica_3/model_2/normalization/sub
 (defined at /usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py:257)
]]
	 [[cond/output/_19/_108]]
  (3) INVALID_ARGUMENT:  required broadcastable shapes
	 [[node replica_3/model_2/normalization/sub
 (defined at /usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py:257)
]]
	 [[GroupCrossDeviceControlEdges_2/NoOp/_225]]
  (4) INVALID_ARGUMENT:  required broadcastable shapes
	 [[node replica_3/model_2/normalization/sub
 (defined at /usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py:257)
]]
0 successful operations.
0 derived errors ignored. [Op:__inference_test_function_119674]

Errors may have originated from an input operation.
Input Source operations connected to node replica_3/model_2/normalization/sub:
In[0] cond/Identity_3 (defined at /usr/local/lib/python3.8/dist-packages/keras/engine/training.py:1355)	
In[1] model_2/normalization/sub/y:

Operation defined at: (most recent call last)
>>>   File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
>>>     self._bootstrap_inner()
>>> 
>>>   File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
>>>     self.run()
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1349, in run_step
>>>     outputs = model.test_step(data)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1303, in test_step
>>>     y_pred = self(x, training=False)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 451, in call
>>>     return self._run_internal_graph(
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 589, in _run_internal_graph
>>>     outputs = node.layer(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py", line 257, in call
>>>     return ((inputs - self.mean) /
>>> 

Input Source operations connected to node replica_3/model_2/normalization/sub:
In[0] cond/Identity_3 (defined at /usr/local/lib/python3.8/dist-packages/keras/engine/training.py:1355)	
In[1] model_2/normalization/sub/y:

Operation defined at: (most recent call last)
>>>   File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
>>>     self._bootstrap_inner()
>>> 
>>>   File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
>>>     self.run()
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1349, in run_step
>>>     outputs = model.test_step(data)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1303, in test_step
>>>     y_pred = self(x, training=False)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 451, in call
>>>     return self._run_internal_graph(
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 589, in _run_internal_graph
>>>     outputs = node.layer(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py", line 257, in call
>>>     return ((inputs - self.mean) /
>>> 

Input Source operations connected to node replica_3/model_2/normalization/sub:
In[0] cond/Identity_3 (defined at /usr/local/lib/python3.8/dist-packages/keras/engine/training.py:1355)	
In[1] model_2/normalization/sub/y:

Operation defined at: (most recent call last)
>>>   File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
>>>     self._bootstrap_inner()
>>> 
>>>   File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
>>>     self.run()
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1349, in run_step
>>>     outputs = model.test_step(data)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1303, in test_step
>>>     y_pred = self(x, training=False)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 451, in call
>>>     return self._run_internal_graph(
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 589, in _run_internal_graph
>>>     outputs = node.layer(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py", line 257, in call
>>>     return ((inputs - self.mean) /
>>> 

Input Source operations connected to node replica_3/model_2/normalization/sub:
In[0] cond/Identity_3 (defined at /usr/local/lib/python3.8/dist-packages/keras/engine/training.py:1355)	
In[1] model_2/normalization/sub/y:

Operation defined at: (most recent call last)
>>>   File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
>>>     self._bootstrap_inner()
>>> 
>>>   File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
>>>     self.run()
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1349, in run_step
>>>     outputs = model.test_step(data)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1303, in test_step
>>>     y_pred = self(x, training=False)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 451, in call
>>>     return self._run_internal_graph(
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 589, in _run_internal_graph
>>>     outputs = node.layer(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py", line 257, in call
>>>     return ((inputs - self.mean) /
>>> 

Input Source operations connected to node replica_3/model_2/normalization/sub:
In[0] cond/Identity_3 (defined at /usr/local/lib/python3.8/dist-packages/keras/engine/training.py:1355)	
In[1] model_2/normalization/sub/y:

Operation defined at: (most recent call last)
>>>   File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
>>>     self._bootstrap_inner()
>>> 
>>>   File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
>>>     self.run()
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1349, in run_step
>>>     outputs = model.test_step(data)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1303, in test_step
>>>     y_pred = self(x, training=False)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 451, in call
>>>     return self._run_internal_graph(
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/functional.py", line 589, in _run_internal_graph
>>>     outputs = node.layer(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1083, in __call__
>>>     outputs = call_fn(inputs, *args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.8/dist-packages/keras/layers/preprocessing/normalization.py", line 257, in call
>>>     return ((inputs - self.mean) /
>>>