Running LSTMs and Conv1Ds over 1D series but with 4 dims - Is there a better way?

Hi! I have been building NLP models with TensorFlow, with and without Keras, for some time. I ran into an issue with Keras and TensorFlow while trying to do the following:

input = tf.keras.layers.Input(shape=(MAX_WORDS, MAX_CHARS))
char_embedded = tf.keras.layers.Embedding(...)(input)

So now char_embedded has shape (dynamic batch size, MAX_WORDS, MAX_CHARS, embedding_dim), and doing any of the following throws an error:


This error occurs because the tensor is 4-dimensional, while both Conv1D and RNN layers expect a 3-dimensional tensor. This setup is common in NER tagging and other NLP tasks for handling unknown words: the characters of each word are reduced to a digest, which is then concatenated to the word embedding.
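To make the shapes concrete, here is a minimal NumPy sketch of where the extra axis comes from. It does not use Keras; the sizes and the names `VOCAB_SIZE`, `EMBEDDING_DIM` and `embedding_table` are illustrative assumptions, and plain array indexing stands in for the Embedding layer.

```python
import numpy as np

# Illustrative sizes, not taken from the original model.
BATCH, MAX_WORDS, MAX_CHARS = 2, 5, 7
VOCAB_SIZE, EMBEDDING_DIM = 30, 4

rng = np.random.default_rng(0)

# Integer character ids: one row of MAX_CHARS ids per word.
char_ids = rng.integers(0, VOCAB_SIZE, size=(BATCH, MAX_WORDS, MAX_CHARS))

# An embedding lookup is row indexing into a (vocab, dim) table,
# so it appends one axis to whatever shape the ids have.
embedding_table = rng.normal(size=(VOCAB_SIZE, EMBEDDING_DIM))
char_embedded = embedding_table[char_ids]

print(char_embedded.shape)  # (2, 5, 7, 4)
```

A layer that expects (batch, steps, features) therefore sees one axis more than it can handle.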
I have a workaround for this:

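lstm = tf.keras.layers.LSTM(...)
digests = []
# Each unstacked `text` has shape (MAX_WORDS, MAX_CHARS, embedding_dim).
for text in tf.unstack(char_embedded):
    # Assumed loop body (missing from the post): run the shared LSTM per sample.
    digests.append(lstm(text))
digests = tf.stack(digests)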

So now digests holds what I wanted, with shape (FIXED_BATCH_SIZE, MAX_WORDS, output_dim).
For unstack to work, it needs to know the batch size while the graph is being built, so I'm forced to make it constant:

input = tf.keras.layers.Input(shape=(MAX_WORDS, MAX_CHARS), batch_size=FIXED_BATCH_SIZE)
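For anyone who wants to follow the shapes of this workaround without building a Keras graph, the NumPy sketch below mimics the loop. The sizes are illustrative, and `char_digest` and `projection` are hypothetical stand-ins for the LSTM (a mean over the character axis followed by a projection), not the real layer.

```python
import numpy as np

# Illustrative sizes, not taken from the original model.
FIXED_BATCH_SIZE, MAX_WORDS, MAX_CHARS = 3, 5, 7
EMBEDDING_DIM, OUTPUT_DIM = 4, 6

rng = np.random.default_rng(0)
char_embedded = rng.normal(
    size=(FIXED_BATCH_SIZE, MAX_WORDS, MAX_CHARS, EMBEDDING_DIM))
projection = rng.normal(size=(EMBEDDING_DIM, OUTPUT_DIM))


def char_digest(words):
    """Hypothetical stand-in for the LSTM: (words, chars, dim) -> (words, out)."""
    return np.tanh(words.mean(axis=1) @ projection)


digests = []
# Iterating over the first axis plays the role of tf.unstack: each `text`
# is 3D, with the word axis acting as the batch axis for the stand-in layer.
for text in char_embedded:
    digests.append(char_digest(text))
digests = np.stack(digests)

print(digests.shape)  # (3, 5, 6)
```

The loop runs once per sample, which is why the number of samples has to be known when the graph is built.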

This whole solution has several disadvantages:

  • It makes the code messy
  • The TF graph looks awful in TensorBoard and is slower to build, so I guess it's a mess internally too
  • It forces me to train and predict on an amount of data divisible by the batch size, which is annoying because I constantly have to pad and unpad things
  • If I have to predict for a single item, I have to pad it and pay the time cost of predicting for FIXED_BATCH_SIZE items
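The padding and unpadding mentioned above looks roughly like the following NumPy sketch. The helper `pad_to_multiple` and the sizes are hypothetical, chosen only for illustration, and zero rows are assumed to be an acceptable padding value.

```python
import numpy as np

FIXED_BATCH_SIZE = 4  # illustrative value


def pad_to_multiple(data, batch_size):
    """Pad axis 0 with zero rows up to the next multiple of batch_size."""
    n_items = data.shape[0]
    n_padding = (-n_items) % batch_size
    pad_width = [(0, n_padding)] + [(0, 0)] * (data.ndim - 1)
    return np.pad(data, pad_width), n_items


# Ten items of shape (MAX_WORDS, MAX_CHARS) = (5, 7), as character ids.
data = np.ones((10, 5, 7), dtype=np.int64)
padded, n_items = pad_to_multiple(data, FIXED_BATCH_SIZE)
print(padded.shape)  # (12, 5, 7)

# After predicting on `padded`, only the first n_items rows are real.
unpadded = padded[:n_items]
print(unpadded.shape)  # (10, 5, 7)
```

With a single item, the same helper pads up to a full batch of FIXED_BATCH_SIZE rows, which is the wasted prediction cost described in the last bullet.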

The last item has one more workaround: train with the fixed-batch-size graph, then reuse those layers to build a second graph with a batch size of 1. But this adds complexity to the code.

I have been using this workaround for 2 years now. It works, and I have had no problems with it other than the ones listed, but I have always thought there has to be a better way, probably a much more straightforward and simple one. So, is there?


