UpdatePruningStep() throwing an error "ValueError: Error processing property '_dropout_mask_cache' of <ContextValueCache at 0x1c2604d02e0>"

I am trying to prune my pre-trained model, and for that it is mandatory to pass UpdatePruningStep() in the callbacks when fitting the model. When I do so, I get the following error:
ValueError: Error processing property '_dropout_mask_cache' of <ContextValueCache at 0x1c2604d02e0>

The code for pruning is as follows:

import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow_model_optimization.sparsity import keras as sparsity

# X_train, Y_train, X_test, Y_test and loaded_model come from earlier in the script (not shown).
num_train_samples = X_train.shape[0]
batch_size = 128
epochs = 4

# One pruning step per training batch, so the schedule ends after the last batch of the last epoch.
end_step = np.ceil(num_train_samples / batch_size).astype(np.int32) * epochs
print(end_step)

new_pruning_params = {
    'pruning_schedule': sparsity.PolynomialDecay(initial_sparsity=0.50,
                                                 final_sparsity=0.90,
                                                 begin_step=0,
                                                 end_step=end_step,
                                                 frequency=100)
}

new_pruned_model = sparsity.prune_low_magnitude(loaded_model, **new_pruning_params)
new_pruned_model.summary()

new_pruned_model.compile(loss=tf.keras.losses.categorical_crossentropy,
                         optimizer='adam',
                         metrics=['accuracy'])

logdir = tempfile.mkdtemp()
callbacks = [
    sparsity.UpdatePruningStep(),  # required: advances the pruning step every training batch
    sparsity.PruningSummaries(log_dir=logdir),
    keras.callbacks.EarlyStopping(monitor='val_loss', patience=10),
]

new_pruned_model.fit(X_train, Y_train.values, batch_size=batch_size, epochs=epochs, verbose=1,
                     validation_data=(X_test, Y_test.values), callbacks=callbacks)

score = new_pruned_model.evaluate(X_test, Y_test.values, verbose=0)
print('Test Loss:', score[0])
print('Test Accuracy:', score[1])
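For reference, a minimal plain-Python sketch of what the schedule above computes. It assumes PolynomialDecay uses the formula sparsity = final + (initial - final) * (1 - progress)^power with the default power=3, and that masks are only updated on steps inside [begin_step, end_step] that are a multiple of frequency after begin_step; this is my reading of the library, not its code. The helper names end_step_for, polynomial_decay_sparsity and should_prune are hypothetical and not part of tensorflow_model_optimization. It does not address the error itself; it only shows what the step counter advanced by UpdatePruningStep is driving.

```python
import math


def end_step_for(num_train_samples, batch_size, epochs):
    # One step per training batch (the counter UpdatePruningStep advances),
    # so the schedule spans every batch of every epoch.
    return math.ceil(num_train_samples / batch_size) * epochs


def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    # Fraction of the schedule completed, clamped to [0, 1].
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))
    # Sparsity ramps from initial_sparsity to final_sparsity along a cubic curve.
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


def should_prune(step, begin_step, end_step, frequency):
    # Masks are (assumed to be) recomputed only inside the range and every `frequency` steps.
    in_range = begin_step <= step <= end_step
    return in_range and (step - begin_step) % frequency == 0


if __name__ == "__main__":
    end_step = end_step_for(10000, 128, 4)  # 10,000 samples is a hypothetical figure
    for step in range(0, end_step + 1):
        if should_prune(step, 0, end_step, 100):
            print(step, round(polynomial_decay_sparsity(step, 0.50, 0.90, 0, end_step), 5))
```

With a hypothetical 10,000 training samples, end_step_for(10000, 128, 4) gives 79 × 4 = 316 steps. With frequency=100 the masks would be updated at steps 0, 100, 200 and 300, so the last update lands slightly before end_step and the sparsity reached there is just under 0.90.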

This seems similar to GitHub issue #753 in tensorflow/model-optimization, "Error processing property '_dropout_mask_cache' when using PrunableLayer with DropoutRNNCellMixin".

Does the fix there work for you?