Help converting TF1 code to TF2

Hello, I am working on code that was written in TF1. I am studying it and trying to convert it to TF2. Today I am struggling to find the equivalent of the zero_state method from tf.compat.v1.nn.rnn_cell.RNNCell. I would like to compute this zero-filled state tensor, but with native TF2 code (using Keras).

Thanks and regards.

Hi @TOP1,

In TensorFlow 2.x, the zero_state method is replaced by the get_initial_state method.

To use get_initial_state, first create an instance of the RNN cell, then call its get_initial_state method with the batch size and dtype as arguments.

```python
import tensorflow as tf
cell = tf.keras.layers.LSTMCell(128)  # 128 units
batch_size = 32  # your choice
initial_state = cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
```

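This creates an LSTM cell with 128 units and then calls get_initial_state to build its zero-filled initial state for the given batch size and data type. For an LSTMCell, the result is a list of two tensors, the hidden state h and the cell state c, each of shape [batch_size, 128].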
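A minimal sketch, assuming TensorFlow 2.x with its bundled Keras, that checks the returned state really is zero-filled and then feeds it to one step of the cell. The batch size of 4 and input size of 10 are arbitrary values chosen for illustration. The try/except covers an API difference: Keras 2 (up to TF 2.15) requires dtype here, while Keras 3 (TF 2.16 and later) appears to accept only batch_size and uses the cell's own compute dtype.

```python
import tensorflow as tf

units = 128
batch_size = 4   # arbitrary batch size for illustration
input_dim = 10   # arbitrary feature size for illustration

cell = tf.keras.layers.LSTMCell(units)

# Keras 2 expects a dtype argument; Keras 3 appears to accept only
# batch_size, so fall back if the dtype keyword is rejected.
try:
    state = cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
except TypeError:
    state = cell.get_initial_state(batch_size=batch_size)

# For an LSTMCell the state is a list [h, c]; both tensors are all zeros.
h, c = state
print(h.shape.as_list(), c.shape.as_list())  # [4, 128] [4, 128]
print(int(tf.math.count_nonzero(h)), int(tf.math.count_nonzero(c)))  # 0 0

# Run a single step of the cell starting from the zero state.
x = tf.random.normal([batch_size, input_dim])
output, new_state = cell(x, state)
print(output.shape.as_list())  # [4, 128]
```

In practice you rarely need to build this state by hand: a tf.keras.layers.RNN or tf.keras.layers.LSTM layer creates the same zero state internally when no initial_state is passed.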

I hope this helps. Please also go through the tutorial "Migrate from TensorFlow 1.x to TensorFlow 2" for more details.



Hi @Laxma_Reddy_Patlolla, thanks for your response. Here is what I did:

```python
zero_state = tf.zeros(shape=[batch_size, self.cell.state_size], dtype=dtype)
```

I did it this way because the RNN documentation did not make it clear to me that get_initial_state returns a tensor full of zeros.
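The tf.zeros approach works when self.cell.state_size is a single integer, as it is for GRUCell or SimpleRNNCell. For an LSTMCell, state_size is a list of two integers, so one zero tensor is needed per state. A minimal sketch of a more general version, using tf.nest.map_structure to follow whatever structure state_size has; zero_state_for is a hypothetical helper name, and it assumes every entry of state_size is a plain int (true for the built-in Keras cells mentioned here):

```python
import tensorflow as tf


def zero_state_for(cell, batch_size, dtype=tf.float32):
    """Hypothetical helper mimicking TF1's RNNCell.zero_state.

    Creates one zero tensor of shape [batch_size, size] for every entry
    in cell.state_size, preserving its structure (single int -> tensor,
    list of ints -> list of tensors).
    """
    return tf.nest.map_structure(
        lambda size: tf.zeros([batch_size, size], dtype=dtype),
        cell.state_size,
    )


gru = tf.keras.layers.GRUCell(64)    # state_size is a single int
lstm = tf.keras.layers.LSTMCell(64)  # state_size is a list of two ints, [h, c]

gru_state = zero_state_for(gru, batch_size=8)
lstm_state = zero_state_for(lstm, batch_size=8)

print(gru_state.shape.as_list())                 # [8, 64]
print([s.shape.as_list() for s in lstm_state])   # [[8, 64], [8, 64]]
```

One more detail when comparing against TF1 results: tf.compat.v1.nn.rnn_cell.LSTMCell.zero_state returns an LSTMStateTuple ordered (c, h), while the Keras LSTMCell keeps its state as [h, c]. Code that indexes the state by position needs the two swapped when porting.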

Anyway, I checked, and my version gives the same results as yours. Thanks also for the link about the TF1-to-TF2 migration. I am new to TensorFlow, so starting by translating TF1 code to TF2 is not the best way to learn it! But I have to do it anyway, and at least it forces me to dig deep into the documentation.
Best regards.