Recommended model architecture for text classification

Hello everyone,

I want to train a text classification model that assigns tweets of German politicians to their respective party. For this I gathered a dataset of 4.8 million tweets as training data. Additionally, I want to use BERT to generate embeddings that are fed into a CNN.

I then wanted to train incrementally on splits of the large dataset, so I can monitor the performance of the model better. However, after one epoch of training on ~50,000 samples, the model already started to overfit, which led me to believe that the model architecture is too complex for the task at hand, because I'm using three consecutive chains of convolutional, pooling, and dropout layers (a rough sketch of that stack is below). So now my question to the community is: what kind of model architecture would you recommend for a task like this? Any feedback and experiences are welcome :)
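To make the setup more concrete, the kind of stack I mean looks roughly like this (a minimal PyTorch sketch; the layer sizes and dropout rates here are placeholders, not my exact configuration):

```python
import torch
import torch.nn as nn

class ConvPoolClassifier(nn.Module):
    """Three conv -> pool -> dropout blocks over BERT token embeddings."""

    def __init__(self, embed_dim: int = 768, num_classes: int = 6):
        super().__init__()
        self.blocks = nn.Sequential(
            nn.Conv1d(embed_dim, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(0.3),
            nn.Conv1d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(0.3),
            nn.Conv1d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(0.3),
        )
        self.head = nn.Linear(64, num_classes)

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        # embeddings: (batch, seq_len, embed_dim) coming from a frozen BERT
        x = embeddings.transpose(1, 2)   # Conv1d expects (batch, channels, seq_len)
        x = self.blocks(x)               # (batch, 64, seq_len // 8)
        x = x.mean(dim=2)                # global average pool to a fixed length
        return self.head(x)              # (batch, num_classes) logits
```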

Hi @Viindo. There can be multiple reasons why this overfits, and it is very difficult to help based on the information you provided in your original post. Maybe you can be more specific, e.g. about your CNN model?

The BERT you are using is frozen, right? And why do you even have a CNN after BERT? Can’t you just train on the CLS embedding?

PS: You can't overfit within a single epoch. That shouldn't be possible.

Yes, you are right, it does not overfit. My issue was that I didn't save the model correctly, which resulted in a loss of the trained weights and led me to believe the model doesn't perform well on unseen data. Regarding BERT, it is frozen. Initially I wanted to use word embeddings such as Word2Vec or fastText, but I learned that transformer models can create embeddings as well, which is why I am using BERT to create embeddings and capture contextual and syntactic information about the data. The goal of the CNN afterwards is to extract further features and reduce dimensionality. I basically swapped a classic word embedding algorithm for BERT.
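To make that concrete, the embedding step looks roughly like this (a sketch using Hugging Face transformers; the German checkpoint name is just an example, not necessarily the one I use):

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Example checkpoint; any German BERT variant works the same way.
tokenizer = AutoTokenizer.from_pretrained("bert-base-german-cased")
bert = AutoModel.from_pretrained("bert-base-german-cased")
bert.eval()  # frozen: BERT weights are never updated

tweets = ["Wir fordern mehr Investitionen in den Klimaschutz."]
batch = tokenizer(tweets, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    outputs = bert(**batch)

# Token-level embeddings, shape (batch, seq_len, 768); this is what feeds the CNN.
token_embeddings = outputs.last_hidden_state
```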

The CLS token is a fixed-length embedding that represents the whole sentence. You should be able to use that instead of using convolutions to get to a fixed length.
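Something along these lines should already be enough (just a sketch, assuming a frozen German BERT checkpoint; the checkpoint name and number of parties are placeholders):

```python
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-german-cased")
bert = AutoModel.from_pretrained("bert-base-german-cased")
for p in bert.parameters():
    p.requires_grad = False  # keep BERT frozen, train only the classification head

classifier = nn.Linear(768, 6)  # 6 = number of parties, adjust as needed

batch = tokenizer(["Beispiel-Tweet"], padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    cls_embedding = bert(**batch).last_hidden_state[:, 0]  # [CLS] token, shape (batch, 768)

logits = classifier(cls_embedding)  # feed into CrossEntropyLoss during training
```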