LSTM in JAX & Flax (Complete example with code and notebook)