TFA BeamSearchDecoder Clarification Request

Hi,

If the question seems too dumb, it is because I am new to TensorFlow.
I was implementing a toy encoder-decoder problem using TensorFlow 2’s TFA seq2seq implementation.
The API was easy to follow until I tried to replace my BasicDecoder with BeamSearchDecoder.
My question is about how to initialize the start_tokens and end_token arguments of BeamSearchDecoder.

Here is a copy of the implementation. Any help is appreciated.


import numpy as np
import tensorflow as tf
from tensorflow import keras

tf.keras.backend.clear_session()
tf.random.set_seed(42)

enc_vocab_size = len(train_vocab) + 1
dec_vocab_size = len(target_vocab) + 1
embed_size = 10


import tensorflow_addons as tfa

encoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)
decoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)
sequence_lengths = keras.layers.Input(shape=[], dtype=np.int32)


encoder_embeddings = keras.layers.Embedding(enc_vocab_size, embed_size)(encoder_inputs)
encoder = keras.layers.LSTM(512, return_state = True)
encoder_outputs, state_h, state_c = encoder(encoder_embeddings)
encoder_state = [state_h, state_c]


sampler = tfa.seq2seq.sampler.TrainingSampler()


decoder_embeddings = keras.layers.Embedding(dec_vocab_size, embed_size)(decoder_inputs)
decoder_cell = keras.layers.LSTMCell(512)
output_layer = keras.layers.Dense(dec_vocab_size)



beam_width = 10
start_tokens = tf.zeros([32], tf.dtypes.int32)
end_tokens = tf.constant([1], tf.dtypes.int32)
decoder = tfa.seq2seq.beam_search_decoder.BeamSearchDecoder(cell = decoder_cell, beam_width = beam_width, output_layer = output_layer)
decoder_initial_state = tfa.seq2seq.beam_search_decoder.tile_batch(encoder_state, multiplier = beam_width)
outputs, _, _ = decoder(decoder_embeddings, start_tokens = start_tokens, end_token = 0, initial_state = decoder_initial_state)
Y_proba = tf.nn.softmax(outputs.rnn_output)


model = keras.models.Model(inputs = [encoder_inputs, decoder_inputs], outputs = [Y_proba])
model.compile(loss="sparse_categorical_crossentropy", optimizer = 'adam', metrics = ['accuracy'])

Error trace:



---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-99-8287ffcfd4fa> in <module>()
     34 decoder = tfa.seq2seq.beam_search_decoder.BeamSearchDecoder(cell = decoder_cell, beam_width = beam_width, output_layer = output_layer)
     35 decoder_initial_state = tfa.seq2seq.beam_search_decoder.tile_batch(encoder_state, multiplier = beam_width)
---> 36 outputs, _, _ = decoder(decoder_embeddings, start_tokens = start_tokens, end_token = 0, initial_state = decoder_initial_state)
     37 Y_proba = tf.nn.softmax(outputs.rnn_output)
     38 

1 frames
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    690       except Exception as e:  # pylint:disable=broad-except
    691         if hasattr(e, 'ag_error_metadata'):
--> 692           raise e.ag_error_metadata.to_exception(e)
    693         else:
    694           raise

ValueError: Exception encountered when calling layer "beam_search_decoder" (type BeamSearchDecoder).

in user code:

    File "/usr/local/lib/python3.7/dist-packages/tensorflow_addons/seq2seq/beam_search_decoder.py", line 941, in call  *
        self,
    File "/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py", line 262, in wrapper  *
        retval = func(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/tensorflow_addons/seq2seq/decoder.py", line 430, in body  *
        (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(
    File "/usr/local/lib/python3.7/dist-packages/tensorflow_addons/seq2seq/beam_search_decoder.py", line 705, in step  *
        cell_outputs, next_cell_state = self._cell(
    File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler  **
        raise e.with_traceback(filtered_tb) from None

    ValueError: Exception encountered when calling layer "lstm_cell_1" (type LSTMCell).
    
    Dimensions must be equal, but are 80 and 320 for '{{node beam_search_decoder/decoder/while/BeamSearchDecoderStep/lstm_cell_1/mul}} = Mul[T=DT_FLOAT](beam_search_decoder/decoder/while/BeamSearchDecoderStep/lstm_cell_1/Sigmoid_1, beam_search_decoder/decoder/while/BeamSearchDecoderStep/Reshape_2)' with input shapes: [320,80,2048], [320,512].
    
    Call arguments received:
      • inputs=tf.Tensor(shape=(320, None, 10), dtype=float32)
      • states=ListWrapper(['tf.Tensor(shape=(320, 512), dtype=float32)', 'tf.Tensor(shape=(320, 512), dtype=float32)'])
      • training=None


Call arguments received:
  • embedding=tf.Tensor(shape=(None, None, 10), dtype=float32)
  • start_tokens=tf.Tensor(shape=(32,), dtype=int32)
  • end_token=0
  • initial_state=['tf.Tensor(shape=(None, 512), dtype=float32)', 'tf.Tensor(shape=(None, 512), dtype=float32)']
  • training=None
  • kwargs=<class 'inspect._empty'>

Did you take a look at this tutorial:

Hey Ethen,

I implemented your example from scratch and tried to keep it as minimal as possible. You can find it here: [CLARIFICATION REQUEST] Chapter 16 Beam Search · Issue #541 · ageron/handson-ml2 · GitHub

Please refer to this link for the full implementation.
Of course, you can extend it later with different kinds of scheduling and attention mechanisms.

Beam search is used once training is done.
Below I add what happens after you have trained the model.

For the implementation of the encoder-decoder part, see the GitHub issue linked above.

Implementing beam search

def beam_search_inferance_model(beam_width):
  batch_size = tf.shape(encoder_input)[:1]
  max_output_length = Y_train.shape[1]
  start_tokens = tf.fill(dims = batch_size, value = sos_id)
  decoder_initial_state = tfa.seq2seq.tile_batch(encoder_state_HC, multiplier = beam_width)
  beam_search_inference = tfa.seq2seq.BeamSearchDecoder(cell = LSTMCell, beam_width = beam_width, output_layer = output_layer, maximum_iterations = max_output_length)
  outputs, _, _ = beam_search_inference(decoder_embd_layer.variables, start_tokens = start_tokens, end_token = 0, initial_state = decoder_initial_state)
  final_outputs = tf.transpose(outputs.predicted_ids, perm = (0,2,1))
  beam_scores = tf.transpose(outputs.beam_search_decoder_output.scores, perm = (0,2,1))
  return keras.Model(inputs = [encoder_input], outputs = [final_outputs, beam_scores])

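# Build the inference model with 3 beams. This rebinds the name to the
# Keras model, which is what beam_translate() below calls .predict() on.
beam_search_inferance_model = beam_search_inferance_model(3)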
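
A note on the shapes, since that is where the original error comes from. As far as I can tell, the first argument of the BeamSearchDecoder call is meant to be the embedding matrix (or a callable that maps ids to embeddings), which is why the embedding layer's variables are passed above rather than the already-embedded decoder inputs used in the question. start_tokens holds one id per input sequence, end_token is a single scalar id, and tile_batch repeats every batch entry beam_width times. With 32 sequences and beam_width = 10 that gives the 320 seen in the error message. The sketch below mimics this bookkeeping with plain Python lists. It is not TFA code: tile_batch_sketch, the toy states and sos_id = 0 are made up for illustration.

```python
# Library-free sketch of the batch/beam bookkeeping (not TFA code).

def tile_batch_sketch(rows, multiplier):
    """Repeat each batch entry `multiplier` times, keeping the copies adjacent."""
    return [row for row in rows for _ in range(multiplier)]

batch_size = 2
beam_width = 3
sos_id = 0  # assumed id of the start-of-sequence token

# One encoder state vector per input sequence (toy 2-dimensional states).
encoder_state = [[0.1, 0.2], [0.3, 0.4]]

# The decoder state needs one row per (sequence, beam) pair.
tiled_state = tile_batch_sketch(encoder_state, beam_width)

# start_tokens has one entry per input sequence, not one per beam.
start_tokens = [sos_id] * batch_size

print(len(tiled_state))  # batch_size * beam_width rows
print(tiled_state[:beam_width])  # the copies of the first sequence's state
print(start_tokens)
```

This is also why the function above computes batch_size from the encoder input instead of hard-coding 32: the length of start_tokens has to agree with the batch the encoder state was computed from.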

Utility function
I copied this function from TFA’s API tutorial and adapted it.

def beam_translate(sentence):
  X = prepare_date_strs_padded(sentence)
  result, beam_scores = beam_search_inferance_model.predict(X)
  for beam, score in zip(result, beam_scores):
    output = ids_to_date_strs(beam)
    beam_score = [a.sum() for a in score]
    print('Input: %s' % sentence)
    print('-----' * 12)
    for i in range(len(output)):
      print('{} Predicted translation: {}  {}'.format(i + 1, output[i], beam_score[i]))
    print('\n')

Output

beam_translate(["July 14, 1789", "September 01, 2020"])

Input: ['July 14, 1789', 'September 01, 2020']
------------------------------------------------------------
1 Predicted translation: 2288-01-11  -83.7786865234375
2 Predicted translation: 2288-01-10  -83.90345764160156
3 Predicted translation: 2288-01-21  -84.30797576904297


Input: ['July 14, 1789', 'September 01, 2020']
------------------------------------------------------------
1 Predicted translation: 2221-02-26  -79.02340698242188
2 Predicted translation: 2222-02-26  -79.29275512695312
3 Predicted translation: 2221-02-21  -80.06587982177734
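
If it helps to see what the start token, end token and beam width actually do, here is a minimal, library-free sketch of beam search for a single input sequence. It is only an illustration under my own assumptions, not how TFA implements it: toy_beam_search, the token ids (1 = start, 0 = end) and the TOY_MODEL probability table are all invented. Every hypothesis starts from the start token, a hypothesis stops growing once it emits the end token, and after every step only the beam_width best hypotheses are kept.

```python
import math

# Invented toy "model": next-token probabilities given the previous token.
# Token 1 is the assumed start token, token 0 the assumed end token.
TOY_MODEL = {
    1: {2: 0.6, 3: 0.4},
    2: {0: 0.55, 3: 0.45},
    3: {0: 0.9, 2: 0.1},
}

def toy_beam_search(start_token, end_token, beam_width, max_steps):
    """Return up to beam_width (tokens, log_prob) pairs, best first."""
    beams = [([start_token], 0.0)]
    for _ in range(max_steps):
        candidates = []
        for tokens, score in beams:
            if tokens[-1] == end_token:
                # Finished hypotheses are carried over unchanged.
                candidates.append((tokens, score))
                continue
            for token, prob in TOY_MODEL[tokens[-1]].items():
                candidates.append((tokens + [token], score + math.log(prob)))
        # Keep only the beam_width highest-scoring hypotheses.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
        if all(tokens[-1] == end_token for tokens, _ in beams):
            break
    return beams

results = toy_beam_search(start_token=1, end_token=0, beam_width=2, max_steps=4)
for rank, (tokens, score) in enumerate(results, start=1):
    print(rank, tokens, score)
```

In this toy table the best sequence starts with token 3, although token 2 is the more likely first step, which is the case where keeping several beams pays off compared to picking the single best token at each step. Scores built from log-probabilities are negative, which matches the sign of the values printed above.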

I hope this helps!

Cheers,
Kasra