Setting the batch size when using Keras Tuner

Hello everyone,

I have a very specific question about my implementation for setting the batch size when using Keras Tuner's Hyperband. To achieve this, I subclassed `kt.HyperModel` with my own class:

class MyHyperModel(kt.HyperModel):
    """Hypermodel that tunes CNN hyperparameters together with the batch size.

    The batch size is a property of the input pipeline, not of the model, so
    it cannot be applied in ``build``. Instead it is registered as a
    hyperparameter in ``build`` (so it is part of the search space) and
    applied in ``fit``, which batches the incoming ``tf.data.Dataset``
    objects before delegating to ``kt.HyperModel.fit``.

    NOTE(review): the tuner only routes training through ``HyperModel.fit``
    in keras-tuner releases that support overriding it (1.1.0+ per the
    release notes). Older releases call ``model.fit`` directly, bypassing
    this override. The model then receives unbatched samples, e.g. shape
    (2, 400, 400) instead of (None, 2, 400, 400). Confirm that every
    environment (e.g. the Slurm cluster) has a recent enough keras-tuner.
    """

    def __init__(self, *args, **kwargs):
        # Forward any optional kt.HyperModel arguments (e.g. name, tunable);
        # calling with no arguments behaves exactly like before.
        super().__init__(*args, **kwargs)

    def build(self, hyper_params):
        """Register this trial's hyperparameters and build the CNN.

        Args:
            hyper_params: The ``kt.HyperParameters`` for the current trial.

        Returns:
            The model returned by ``get_cnn_stk_simplified``.
        """
        # NOTE(review): 'None' is the string 'None', not Python None; confirm
        # that get_cnn_stk_simplified translates it to "no initializer".
        initializers = ['glorot_uniform', 'glorot_normal', 'he_uniform', 'he_normal', 'None']
        activations = ['relu', 'tanh', 'sigmoid', 'leakyrelu']

        # Dict literals evaluate left to right, so the hyperparameters are
        # registered in the same order as individual assignments would.
        config = {
            'lr': hyper_params.Float('lr', 0.0001, 1.0, sampling='log'),
            'optimizer': hyper_params.Choice('optimizer', ['adam', 'sgd', 'rmsprop']),
            'conv_kernel_initializer': hyper_params.Choice('conv_kernel_initializer', initializers),
            'dense_kernel_initializer': hyper_params.Choice('dense_kernel_initializer', initializers),
            'cnn_activation': hyper_params.Choice('cnn_activation', activations),
            'dense_activation': hyper_params.Choice('dense_activation', activations),
            'num_fc_layers': hyper_params.Int('num_fc_layers', 1, 5),
            'fc_width': hyper_params.Int('fc_width', 1, 100),
            'dropout': hyper_params.Float('dropout', 0.0, 0.5),
        }

        # Registered here only so it joins the search space; the value is
        # consumed in ``fit``, where the datasets are batched.
        hyper_params.Choice('batch_size', [32, 64])

        # NOTE(review): assumes get_cnn_stk_simplified expects this plain
        # dict of values (not the HyperParameters object) - confirm.
        return get_cnn_stk_simplified(config)

    def fit(self, hyper_parameters, model, *args, **kwargs):
        """Batch the training/validation datasets, then train the model.

        Args:
            hyper_parameters: The trial's ``kt.HyperParameters``; its
                ``batch_size`` value is applied to the datasets.
            model: The model returned by ``build``.
            *args: Positional arguments forwarded from ``tuner.search``; the
                first one is treated as the training data when ``x`` is not
                given as a keyword.
            **kwargs: Keyword arguments forwarded from ``tuner.search``.
                ``x`` and the optional ``validation_data`` are assumed to be
                unbatched ``tf.data.Dataset`` objects.

        Returns:
            Whatever ``kt.HyperModel.fit`` returns for ``model``.
        """
        batch_size = hyper_parameters.get('batch_size')

        # Training data may arrive as ``x=`` or as the first positional arg.
        if kwargs.get('x') is not None:
            kwargs['x'] = kwargs['x'].batch(batch_size)
        elif args:
            args = (args[0].batch(batch_size), *args[1:])

        # Validation data is optional; only batch it when it was supplied.
        if kwargs.get('validation_data') is not None:
            kwargs['validation_data'] = kwargs['validation_data'].batch(batch_size)

        return super().fit(hyper_parameters, model, *args, **kwargs)

And here is the code calling my subclass:

# NOTE(review): this snippet appears truncated. As shown, `tuner` is used
# (tuner.search_space_summary(), tuner.project_name) inside the very call
# that assigns it. The kt.Hyperband(...) arguments (hypermodel, objective,
# max epochs, ...) and the start of a tuner.search(...) call seem to be
# missing. Presumably the intent is something like:
#   tuner = kt.Hyperband(MyHyperModel(), ...)
#   tuner.search_space_summary()
#   tuner.search(x=<unbatched dataset>, epochs=EPOCHS, validation_data=..., ...)
# TODO confirm against the original script.
tuner = kt.Hyperband(
        tuner.search_space_summary(), epochs=EPOCHS, validation_data=validation_data, verbose=args.train_verbose, callbacks=[keras.callbacks.TensorBoard(log_dir=args.log_dir+tuner.project_name, histogram_freq=1, write_graph=True, update_freq='batch', profile_batch=1)])

This works as intended when I run it locally on my laptop using only the CPU.
However, when I run it on a Slurm cluster with access to a GPU, Keras Tuner fails with the following error:

  File "/cvmfs/
gine/", line 144, in search
    self.run_trial(trial, *fit_args, **fit_kwargs)
  File "/cvmfs/", line 370, in run_trial
    super(Hyperband, self).run_trial(trial, *fit_args, **fit_kwargs)
  File "/cvmfs/", line 90, in run_trial
    history = self._build_and_fit_model(trial, fit_args, copied_fit_kwargs)
  File "/cvmfs/", line 147, in _build_and_fit_model
    return*fit_args, **fit_kwargs)
  File "/cvmfs/", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/scratch/", line 15, in tf__train_function
    retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
ValueError: in user code:

    File "/cvmfs/", line 1284, in train_function  *
        return step_function(self, iterator)
    File "/cvmfs/", line 1268, in step_function  **
        outputs =, args=(data,))
    File "/cvmfs/", line 1249, in run_step  **
        outputs = model.train_step(data)
    File "/cvmfs/", line 1050, in train_step
        y_pred = self(x, training=True)
    File "/cvmfs/", line 70, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/cvmfs/", line 298, in assert_input_compatibility
        raise ValueError(

    ValueError: Input 0 of layer "stk_cnn_model_simple" is incompatible with the layer: expected shape=(None, 2, 400, 400), found shape=(2, 400, 400)

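From the error message, it seems that my subclass's `fit` method is not called at all. The model receives the unbatched dataset, whose elements have shape (2, 400, 400) instead of the expected (None, 2, 400, 400), and the traceback shows the tuner's `_build_and_fit_model` calling fit directly rather than going through `MyHyperModel.fit`…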

Has anyone encountered a similar issue?