Setting the batch size of a tf.data.Dataset using Keras Tuner

Hello everyone,

I have a very specific question about my implementation for setting the batch size of a tf.data.Dataset when using Keras Tuner's Hyperband. To achieve this, I subclassed `kt.HyperModel` with my own class and overrode its `fit` method:

class MyHyperModel(kt.HyperModel):
    """Hypermodel that also tunes the batch size of unbatched ``tf.data`` inputs.

    ``build`` defines the architecture/optimizer search space and registers a
    ``batch_size`` hyperparameter; ``fit`` applies that batch size to the
    datasets handed to ``tuner.search`` and then delegates to
    ``kt.HyperModel.fit``.

    The datasets passed to ``tuner.search`` must be *unbatched*: calling
    ``.batch()`` on an already-batched dataset adds a second batch dimension.

    NOTE(review): the tuner only calls ``HyperModel.fit`` in newer keras-tuner
    releases (1.1.0+). The traceback in this post shows
    ``keras_tuner/engine/tuner.py::_build_and_fit_model`` calling
    ``model.fit(*fit_args, **fit_kwargs)`` directly. On such an installation
    this ``fit`` is skipped and the model receives unbatched samples, which
    matches the ``expected (None, 2, 400, 400), found (2, 400, 400)`` error.
    Check the keras-tuner version on every environment (laptop vs. cluster).
    """

    def __init__(self, *args, **kwargs):
        # Forward any optional HyperModel arguments (e.g. name, tunable).
        super().__init__(*args, **kwargs)

    def build(self, hyper_params):
        """Register this trial's hyperparameters and build the model.

        Args:
            hyper_params: the trial's ``kt.HyperParameters`` container.

        Returns:
            Whatever ``get_cnn_stk_simplified`` returns for the sampled
            settings (presumably a compiled Keras model; confirm).
        """
        lr = hyper_params.Float('lr', 0.0001, 1.0, sampling='log')
        optimizer = hyper_params.Choice('optimizer', ['adam', 'sgd', 'rmsprop'])
        # NOTE(review): 'None' and 'leakyrelu' are plain strings passed through
        # to get_cnn_stk_simplified; confirm it maps them (Keras itself does not
        # treat the string 'None' as "no initializer").
        conv_kernel_initializer = hyper_params.Choice('conv_kernel_initializer', ['glorot_uniform', 'glorot_normal', 'he_uniform', 'he_normal', 'None'])
        dense_kernel_initializer = hyper_params.Choice('dense_kernel_initializer', ['glorot_uniform', 'glorot_normal', 'he_uniform', 'he_normal', 'None'])
        cnn_activation = hyper_params.Choice('cnn_activation', ['relu', 'tanh', 'sigmoid', 'leakyrelu'])
        dense_activation = hyper_params.Choice('dense_activation', ['relu', 'tanh', 'sigmoid', 'leakyrelu'])
        num_fc_layers = hyper_params.Int('num_fc_layers', 1, 5)
        fc_width = hyper_params.Int('fc_width', 1, 100)
        dropout = hyper_params.Float('dropout', 0.0, 0.5)
        # Registered here so it is part of the search space; it is consumed in
        # fit(), not by the model architecture.
        hyper_params.Choice('batch_size', [32, 64])

        # Separate name so the HyperParameters argument is not shadowed.
        model_params = {
            'lr': lr,
            'optimizer': optimizer,
            'conv_kernel_initializer': conv_kernel_initializer,
            'dense_kernel_initializer': dense_kernel_initializer,
            'cnn_activation': cnn_activation,
            'dense_activation': dense_activation,
            'num_fc_layers': num_fc_layers,
            'fc_width': fc_width,
            'dropout': dropout,
        }
        return get_cnn_stk_simplified(model_params)

    def fit(self, hyper_parameters, model, *args, **kwargs):
        """Batch the training/validation datasets with the tuned batch size, then fit.

        Args:
            hyper_parameters: the trial's ``kt.HyperParameters``; must contain
                ``batch_size`` (registered in ``build``).
            model: the model returned by ``build``.
            *args, **kwargs: the arguments given to ``tuner.search``. The
                training dataset may be passed as ``x=`` or as the first
                positional argument. ``validation_data`` is optional.

        Returns:
            The value returned by ``kt.HyperModel.fit`` for the batched inputs.
        """
        batch_size = hyper_parameters.get('batch_size')

        # Accept the training dataset as keyword `x` or as the first positional arg.
        if 'x' in kwargs:
            kwargs['x'] = kwargs['x'].batch(batch_size)
        elif args:
            args = (args[0].batch(batch_size),) + args[1:]

        # Validation data is optional; only batch it when it was supplied.
        if kwargs.get('validation_data') is not None:
            kwargs['validation_data'] = kwargs['validation_data'].batch(batch_size)

        return super().fit(hyper_parameters, model, *args, **kwargs)

And here is the code that calls my subclass:

import os

# Hyperband tuner over the MyHyperModel search space.
tuner = kt.Hyperband(
    hypermodel=MyHyperModel(),
    objective='val_loss',
    max_epochs=10,
    directory=args.log_dir,
    project_name='stk_cnn_hyperband',
    overwrite=True,
)
tuner.search_space_summary()

# Use os.path.join so a log_dir without a trailing separator is not fused with
# the project name (plain `+` would yield e.g. "logsstk_cnn_hyperband").
tensorboard_cb = keras.callbacks.TensorBoard(
    log_dir=os.path.join(args.log_dir, tuner.project_name),
    histogram_freq=1,
    write_graph=True,
    update_freq='batch',
    profile_batch=1,
)

# train_data / validation_data are expected to be *unbatched* tf.data.Datasets:
# MyHyperModel.fit batches them with the tuned batch size. That only happens if
# the installed keras-tuner calls HyperModel.fit (1.1.0+) — confirm the version.
# NOTE(review): Hyperband chooses per-trial epochs from max_epochs; confirm
# whether EPOCHS has any effect here.
tuner.search(
    x=train_data,
    epochs=EPOCHS,
    validation_data=validation_data,
    verbose=args.train_verbose,
    callbacks=[tensorboard_cb],
)

This works as intended when I run it locally on my laptop using only the CPU.
However, when I run it on a Slurm cluster with access to a GPU, Keras Tuner does not work. I get the following error:

  File "/cvmfs/sft.cern.ch/lcg/views/LCG_104cuda/x86_64-centos8-gcc11-opt/lib/python3.9/site-packages/keras_tuner/engine/base_tuner.py", line 144, in search
    self.run_trial(trial, *fit_args, **fit_kwargs)
  File "/cvmfs/sft.cern.ch/lcg/views/LCG_104cuda/x86_64-centos8-gcc11-opt/lib/python3.9/site-packages/keras_tuner/tuners/hyperband.py", line 370, in run_trial
    super(Hyperband, self).run_trial(trial, *fit_args, **fit_kwargs)
  File "/cvmfs/sft.cern.ch/lcg/views/LCG_104cuda/x86_64-centos8-gcc11-opt/lib/python3.9/site-packages/keras_tuner/engine/multi_execution_tuner.py", line 90, in run_trial
    history = self._build_and_fit_model(trial, fit_args, copied_fit_kwargs)
  File "/cvmfs/sft.cern.ch/lcg/views/LCG_104cuda/x86_64-centos8-gcc11-opt/lib/python3.9/site-packages/keras_tuner/engine/tuner.py", line 147, in _build_and_fit_model
    return model.fit(*fit_args, **fit_kwargs)
  File "/cvmfs/sft.cern.ch/lcg/views/LCG_104cuda/x86_64-centos8-gcc11-opt/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/scratch/__autograph_generated_file20gmt774.py", line 15, in tf__train_function
    retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
ValueError: in user code:

    File "/cvmfs/sft.cern.ch/lcg/views/LCG_104cuda/x86_64-centos8-gcc11-opt/lib/python3.9/site-packages/keras/engine/training.py", line 1284, in train_function  *
        return step_function(self, iterator)
    File "/cvmfs/sft.cern.ch/lcg/views/LCG_104cuda/x86_64-centos8-gcc11-opt/lib/python3.9/site-packages/keras/engine/training.py", line 1268, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/cvmfs/sft.cern.ch/lcg/views/LCG_104cuda/x86_64-centos8-gcc11-opt/lib/python3.9/site-packages/keras/engine/training.py", line 1249, in run_step  **
        outputs = model.train_step(data)
    File "/cvmfs/sft.cern.ch/lcg/views/LCG_104cuda/x86_64-centos8-gcc11-opt/lib/python3.9/site-packages/keras/engine/training.py", line 1050, in train_step
        y_pred = self(x, training=True)
    File "/cvmfs/sft.cern.ch/lcg/views/LCG_104cuda/x86_64-centos8-gcc11-opt/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/cvmfs/sft.cern.ch/lcg/views/LCG_104cuda/x86_64-centos8-gcc11-opt/lib/python3.9/site-packages/keras/engine/input_spec.py", line 298, in assert_input_compatibility
        raise ValueError(

    ValueError: Input 0 of layer "stk_cnn_model_simple" is incompatible with the layer: expected shape=(None, 2, 400, 400), found shape=(2, 400, 400)

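From the error message, it seems that my subclass's fit function is not called at all: the model receives single unbatched samples of shape (2, 400, 400), and the traceback shows keras_tuner/engine/tuner.py (`_build_and_fit_model`) calling `model.fit(*fit_args, **fit_kwargs)` directly instead of going through `HyperModel.fit`…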

Has anyone encountered a similar issue?