INVALID_ARGUMENT: Invalid reduction arguments: Axes contains duplicate dimension: 1

Model looks like this:

import tensorflow as tf
from transformers import TFAutoModel

max_length = 100

# Define the CNN model
inputs = tf.keras.Input(shape=(max_length,), dtype=tf.int32)
input_masks = tf.keras.Input(shape=(max_length,), dtype=tf.int32)

bert_model = TFAutoModel.from_pretrained('dbmdz/bert-base-german-uncased')
bert_model.trainable = False
bert_output = bert_model(inputs, attention_mask=input_masks)[0]

# Apply a 1D convolution to the BERT embeddings
conv1d = tf.keras.layers.Conv1D(filters=128, kernel_size=3, padding='valid', 
                                activation='relu')(bert_output)

pool1d = tf.keras.layers.GlobalMaxPooling1D()(conv1d)
# training=True keeps this dropout active outside of training as well
dropout1 = tf.keras.layers.Dropout(0.2)(pool1d, training=True)

# Flatten the pooled features
flatten = tf.keras.layers.Flatten()(dropout1)

# Apply a dense layer to classify the features
dense = tf.keras.layers.Dense(units=128, activation='relu')(flatten)
dropout4 = tf.keras.layers.Dropout(0.5)(dense, training=True)
outputs = tf.keras.layers.Dense(units=5, activation='softmax')(dropout4)

model = tf.keras.models.Model(inputs=[inputs, input_masks], outputs=outputs)

# Compile the model
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['sparse_categorical_accuracy'],
              run_eagerly=True)
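
The tokenizer used further down is not shown; a minimal sketch of how it could be set up, assuming the standard AutoTokenizer matching the same checkpoint:

from transformers import AutoTokenizer

# Assumption: the tokenizer matches the frozen BERT checkpoint used above
tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-german-uncased')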

Training looks like this:

history = model.fit(
    bert_encode_generator(train_df['tweet'].values, train_df['label'].values,
                          tokenizer, 32, 100),
    validation_data=bert_encode_generator(validation_df['tweet'].values,
                                          validation_df['label'].values,
                                          tokenizer, 32, 100),
    callbacks=[early_stopping],
    epochs=args.epochs,
    steps_per_epoch=len(train_df) // 32,
    validation_steps=len(validation_df) // 32)
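
The early_stopping callback is not shown above; a minimal sketch of what it could look like (the monitored metric and patience are assumptions):

# Assumption: a standard Keras EarlyStopping callback watching validation loss
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                  patience=3,
                                                  restore_best_weights=True)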

The bert_encode_generator looks like this:

import math

def bert_encode_generator(texts, labels, tokenizer, batch_size, max_length=100):
    num_texts = len(texts)
    num_batches = math.ceil(num_texts / batch_size)

    while True:

        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min((batch_idx + 1) * batch_size, num_texts)

            batch_texts = texts[start_idx:end_idx]
            batch_labels = labels[start_idx:end_idx]

            all_tokens = []
            all_masks = []

            for text in batch_texts:
                encoding = tokenizer.encode_plus(text, add_special_tokens=False, max_length=max_length,
                                                 return_attention_mask=True, return_tensors='tf',
                                                 padding='max_length', truncation=True)
                input_ids = encoding['input_ids'][0]
                attention_mask = encoding['attention_mask'][0]
                all_tokens.append(input_ids)
                all_masks.append(attention_mask)

            yield [tf.convert_to_tensor(all_tokens), tf.convert_to_tensor(all_masks)], batch_labels
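
For reference, a quick way to sanity-check what a single batch looks like, assuming the same tokenizer and dataframe as in the training call:

# Sketch: pull one batch and compare its shapes/dtypes with the model's Input specs
gen = bert_encode_generator(train_df['tweet'].values, train_df['label'].values,
                            tokenizer, 32, 100)
(batch_ids, batch_masks), batch_labels = next(gen)
print(batch_ids.shape, batch_ids.dtype)      # expected: (32, 100), int32
print(batch_masks.shape, batch_masks.dtype)  # expected: (32, 100), int32
print(len(batch_labels))                     # expected: 32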

The error in the title occurred twice while calling model.fit: the first time after 200 steps, the second after 1420 steps. I don't really know what's going on here, and the error message isn't much help either. The problem most likely stems from the BERT layer, though, because the following was logged as well:

Exception encountered when calling layer 'embeddings' (type TFBertEmbeddings).

Call arguments received by layer 'embeddings' (type TFBertEmbeddings):
  • input_ids=tf.Tensor(shape=(32, 100), dtype=int32)
  • position_ids=None
  • token_type_ids=tf.Tensor(shape=(32, 100), dtype=int32)
  • inputs_embeds=None
  • past_key_values_length=0
  • training=True

If anyone has encountered this error before or knows where it originates, please let me know.
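
In case it helps to narrow this down, here is one way the failing batch could be isolated: wrap the generator in tf.data with an explicit signature. This is only a sketch that assumes the same train_df, tokenizer, and generator as above; the label dtype is an assumption.

import tensorflow as tf

# Sketch: every batch is checked against the signature before it reaches the model
def checked_generator():
    for (ids, masks), labels in bert_encode_generator(
            train_df['tweet'].values, train_df['label'].values, tokenizer, 32, 100):
        # yield tuples so the structure matches output_signature exactly
        yield (ids, masks), tf.convert_to_tensor(labels)

train_ds = tf.data.Dataset.from_generator(
    checked_generator,
    output_signature=(
        (tf.TensorSpec(shape=(None, 100), dtype=tf.int32),
         tf.TensorSpec(shape=(None, 100), dtype=tf.int32)),
        tf.TensorSpec(shape=(None,), dtype=tf.int64),   # assumed label dtype
    ),
)

# Iterate one epoch's worth of batches; a malformed batch should now raise an
# error describing the shape mismatch instead of failing inside the BERT layer.
for step, _ in enumerate(train_ds.take(len(train_df) // 32)):
    pass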