I trained a model with the following architecture:

```python
import tensorflow as tf
from transformers import BertConfig, TFAutoModelForSequenceClassification

# MODEL_NAME, MAX_LENGTH and label2id are defined elsewhere
bert_config = BertConfig.from_pretrained(MODEL_NAME)
bert_config.output_hidden_states = True
backbone = TFAutoModelForSequenceClassification.from_pretrained(MODEL_NAME, config=bert_config)

input_ids = tf.keras.layers.Input(shape=(MAX_LENGTH,), name='input_ids', dtype='int32')
features = backbone(input_ids)[1][-1]  # last layer of the hidden_states tuple
pooling = tf.keras.layers.GlobalAveragePooling1D()(features)
dense = tf.keras.layers.Dense(len(label2id), name='output', activation=tf.nn.softmax)(pooling)
model = tf.keras.Model(inputs=[input_ids], outputs=[dense])
```

and saved the model in two different ways. The first one is


and the second one is


I can load my model (`model.load_weights()`) from both of these options without any error, and inference is fine. In short, everything works as expected.

But if I start a new session and load my model again, inference is bad, as if the model had random weights instead of my saved ones.
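One way to confirm whether `load_weights()` really restored the trained values in the new session is to dump `model.get_weights()` (a list of NumPy arrays in Keras) to a file right after training and compare against that snapshot after reloading. A minimal NumPy-only sketch; `snapshot_weights`, `compare_to_snapshot` and the file name are hypothetical helpers, not part of Keras or transformers:

```python
import numpy as np

def snapshot_weights(weights, path):
    """Save a list of arrays (e.g. the output of model.get_weights()) to an .npz file."""
    np.savez(path, *weights)  # arrays are stored as arr_0, arr_1, ...

def compare_to_snapshot(weights, path, atol=1e-6):
    """Return the indices of arrays that differ from the snapshot (empty list = all restored)."""
    with np.load(path) as data:
        saved = [data[f"arr_{i}"] for i in range(len(data.files))]
    if len(saved) != len(weights):
        raise ValueError(f"expected {len(saved)} arrays, got {len(weights)}")
    # a shape mismatch counts as a difference; allclose is only called on equal shapes
    return [i for i, (a, b) in enumerate(zip(saved, weights))
            if a.shape != b.shape or not np.allclose(a, b, atol=atol)]
```

In the training session you would call `snapshot_weights(model.get_weights(), "trained_weights.npz")`; in the new session, after rebuilding the model and calling `load_weights`, `compare_to_snapshot(model.get_weights(), "trained_weights.npz")` should return an empty list if every array was restored. This assumes the rebuilt model lists its weights in the same order, which should hold when it is built by the same code.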

I tried other ways of saving the model as well, but they do not work either. Is there perhaps a special way to save a model that contains a transformer layer?

Thanks in advance!
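For reference, the head in the architecture above averages the backbone's last hidden layer over the token axis (`GlobalAveragePooling1D`) and then applies a softmax `Dense` layer. A NumPy sketch of just that head, assuming `backbone(input_ids)[1][-1]` is the last hidden state with shape `(batch, MAX_LENGTH, hidden_size)`; `head_forward`, the sizes and the random weights are made up for illustration:

```python
import numpy as np

def head_forward(last_hidden, kernel, bias):
    """Mirror of GlobalAveragePooling1D -> Dense(softmax) on a (batch, seq_len, hidden) array."""
    pooled = last_hidden.mean(axis=1)                      # (batch, hidden): mean over tokens
    logits = pooled @ kernel + bias                        # (batch, num_labels)
    logits = logits - logits.max(axis=1, keepdims=True)    # stabilise exp()
    probs = np.exp(logits)
    return probs / probs.sum(axis=1, keepdims=True)        # each row sums to 1

# Illustrative sizes only (hypothetical, not taken from the post)
batch, max_length, hidden, num_labels = 2, 8, 16, 3
rng = np.random.default_rng(0)
last_hidden = rng.normal(size=(batch, max_length, hidden))
kernel = rng.normal(size=(hidden, num_labels))
bias = np.zeros(num_labels)

probs = head_forward(last_hidden, kernel, bias)
print(probs.shape)  # (2, 3)
```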

Have you experienced the same problem with


Unfortunately, yes

Have you tried posting this on the Hugging Face forum? It seems that you are using one of their classes, `TFAutoModelForSequenceClassification`.

I suppose your post is this one:

Have you checked whether it works with the `tf.function` and model signature approach:

See also:

In fact, your code is OK. Just follow my first post and you will get reproducible results every time you run on Colab, even if your session has expired. If you don't want a function, you can also do it this way, right after importing these two modules:

```python
from tensorflow.python.framework import ops
import tensorflow as tf
```


See, the problem is that TensorFlow reinitializes variables when your session expires, so to solve this you have to manually set your own random seed after importing TensorFlow and NumPy.

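Create a function this way:

```python
import numpy as np
import tensorflow as tf

SEED = 42  # choose any seed of your choice
def reproducibleResult(seed: int):
    # body missing from the post; reconstructed: seed NumPy and TensorFlow
    np.random.seed(seed)
    tf.random.set_seed(seed)

reproducibleResult(SEED)  # call your function
```

Call your function and problem solved :slight_smile: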
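To illustrate what re-seeding does: with the same seed, two fresh random initializations come out identical, so every new session starts from the same values. A NumPy-only sketch; `init_weights` is a hypothetical stand-in for a layer's random initializer (in TensorFlow, `tf.random.set_seed` plays the role that `np.random.seed` plays here):

```python
import numpy as np

def init_weights(shape, seed=None):
    """Hypothetical stand-in for a layer initializer: draw weights from N(0, 0.02)."""
    if seed is not None:
        np.random.seed(seed)  # reset the global NumPy random state
    return np.random.normal(0.0, 0.02, size=shape)

a = init_weights((4, 3), seed=42)
b = init_weights((4, 3), seed=42)  # same seed, so identical to a
c = init_weights((4, 3))           # no reseed: continues the stream, so differs from a
print(np.array_equal(a, b))  # True
```

Seeding fixes the values that random initialization produces; anything loaded afterwards with `load_weights` overwrites them.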