Issue with training an MLP with DenseFlipout layers

Hi, I would like some advice about a problem I am having with an MLP built from DenseFlipout layers.
I have been trying to train the network for a regression task, using tfp.layers.DenseFlipout with mostly the default settings:

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import Input, Model
from tensorflow.keras.optimizers import Adam

tfd = tfp.distributions


def create_flipout_bnn_model(train_size):
  def normal_sp(params):
    # First output parameterises the mean, second one the (softplus-transformed) scale.
    return tfd.Normal(loc=params[:, 0:1],
                      scale=1e-3 + tf.math.softplus(0.05 * params[:, 1:2]))

  # Scale the KL terms by the training-set size so they are weighted per example.
  kernel_divergence_fn = lambda q, p, _: tfp.distributions.kl_divergence(q, p) / train_size
  bias_divergence_fn = lambda q, p, _: tfp.distributions.kl_divergence(q, p) / train_size  # currently unused, the layers keep their default bias settings

  inputs = Input(shape=(1,), name="input_layer")

  hidden = tfp.layers.DenseFlipout(50,
                                   kernel_divergence_fn=kernel_divergence_fn,
                                   activation="relu", name="DenseFlipout_layer_1")(inputs)
  hidden = tfp.layers.DenseFlipout(50,
                                   kernel_divergence_fn=kernel_divergence_fn,
                                   activation="relu", name="DenseFlipout_layer_2")(hidden)
  hidden = tfp.layers.DenseFlipout(50,
                                   kernel_divergence_fn=kernel_divergence_fn,
                                   activation="relu", name="DenseFlipout_layer_3")(hidden)
  params = tfp.layers.DenseFlipout(2,
                                   kernel_divergence_fn=kernel_divergence_fn,
                                   name="DenseFlipout_layer_4")(hidden)
  # Map the two outputs to a Normal distribution so the model predicts a full distribution.
  dist = tfp.layers.DistributionLambda(normal_sp, name="normal_sp")(params)

  model = Model(inputs=inputs, outputs=dist)
  return model
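
NLL is the negative log-likelihood of the predicted distribution; here is a minimal sketch of how it can be defined (my actual definition may differ slightly):

def NLL(y_true, y_pred):
  # y_pred is the distribution object returned by the DistributionLambda layer,
  # so the loss is simply the negative log-likelihood of the targets under it.
  return -y_pred.log_prob(y_true)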

 
flipout_BNN = create_flipout_bnn_model(train_size=train_size)
flipout_BNN.compile(optimizer=Adam(learning_rate=0.002),
                    loss=NLL,
                    metrics=[tf.keras.metrics.RootMeanSquaredError()])

flipout_BNN.summary()

history_flipout_BNN = flipout_BNN.fit(X_train, y_train,
                                      epochs=50000,
                                      verbose=0,
                                      batch_size=batch_size,
                                      validation_data=(X_val, y_val))
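
This is roughly how I plot the loss curves (a minimal sketch, not my exact plotting code):

import matplotlib.pyplot as plt

# Training and validation loss from the History object returned by fit();
# the reported loss is the NLL plus the KL terms added by the Flipout layers.
plt.plot(history_flipout_BNN.history["loss"], label="train loss")
plt.plot(history_flipout_BNN.history["val_loss"], label="val loss")
plt.xlabel("epoch")
plt.ylabel("loss (NLL + KL)")
plt.legend()
plt.show()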

However, the plot of the loss keeps showing spikes no matter how many epochs I train for. What can I do to avoid this? I think it is related to the fact that the weights are sampled from a distribution on every forward pass, but still… shouldn't those spikes eventually disappear?
[Attached image: plot of the loss showing the recurring spikes]
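
One check I was considering (just a sketch; n_passes is an arbitrary name and value): average the validation loss over several stochastic forward passes, since each call samples a fresh set of weights. If the averaged value is stable, the spikes would mostly be Monte Carlo sampling noise rather than an optimisation problem.

import numpy as np

n_passes = 20  # each evaluate() call draws a new set of weights
losses = [flipout_BNN.evaluate(X_val, y_val, verbose=0)[0] for _ in range(n_passes)]
print(np.mean(losses), np.std(losses))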
